fixing overoptimization of load -> store

Jakub Doka 2024-10-24 15:39:38 +02:00
parent 127e8dcb38
commit cb88edea1f
GPG key ID: C6E9A89936B8C143
6 changed files with 84 additions and 70 deletions
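
The core of the change is the extra guard in the `Stre` peephole (second file below): a store is no longer folded into the earlier store it chains to when the value being stored comes from a `Load`, which is the pattern the new `small_struct_assignment` test exercises. Below is a minimal standalone sketch of that guard over a toy node array; `Kind`, `Node`, and `can_fold_store` are illustrative stand-ins, not the compiler's actual `Nodes`/`Vc` types.

```rust
// Toy IR nodes: for a store, inputs are [ctrl, value, region, prior store],
// mirroring the input layout used in the hunks below.
#[derive(Clone, Copy, PartialEq)]
enum Kind { Void, Cint, Stck, Load, Stre }

struct Node { kind: Kind, inputs: Vec<usize>, lock_rc: u32 }

const VOID: usize = 0;

// Mirrors the peephole condition after this commit: a store may only absorb
// the older store it chains to if (among the existing checks) its own value
// does NOT come from a Load.
fn can_fold_store(nodes: &[Node], target: usize) -> bool {
    let t = &nodes[target];
    t.kind == Kind::Stre
        && t.inputs[2] != VOID                         // concrete region
        && t.inputs.len() == 4                         // chained to a prior store
        && nodes[t.inputs[1]].kind != Kind::Load       // guard added by this commit
        && nodes[t.inputs[3]].kind == Kind::Stre       // prior node is a store
        && nodes[t.inputs[3]].lock_rc == 0
        && nodes[t.inputs[3]].inputs[2] == t.inputs[2] // same region
}

fn main() {
    let mut nodes = vec![
        Node { kind: Kind::Void, inputs: vec![], lock_rc: 0 },           // 0: VOID
        Node { kind: Kind::Stck, inputs: vec![], lock_rc: 0 },           // 1: stack slot
        Node { kind: Kind::Cint, inputs: vec![], lock_rc: 0 },           // 2: constant
        Node { kind: Kind::Stre, inputs: vec![0, 2, 1, 0], lock_rc: 0 }, // 3: store const -> slot
        Node { kind: Kind::Load, inputs: vec![0, 1], lock_rc: 0 },       // 4: load from memory
        Node { kind: Kind::Stre, inputs: vec![0, 4, 1, 3], lock_rc: 0 }, // 5: store loaded value -> slot
    ];
    // Node 5 stores a value that comes from a Load, so the fold is refused.
    assert!(!can_fold_store(&nodes, 5));
    // A store of the plain constant over the same chain would still fold.
    nodes.push(Node { kind: Kind::Stre, inputs: vec![0, 2, 1, 3], lock_rc: 0 });
    assert!(can_fold_store(&nodes, 6));
}
```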


@@ -544,10 +544,23 @@ main := fn(): int {
 ```hb
 Color := struct {r: u8, g: u8, b: u8, a: u8}
 white := Color.(255, 255, 255, 255)
-u32_to_color := fn(v: u32): Color return @bitcast(v)
+u32_to_color := fn(v: u32): Color return @as(Color, @bitcast(u32_to_u32(@bitcast(v))))
 u32_to_u32 := fn(v: u32): u32 return v
 main := fn(): int {
-    return u32_to_color(@bitcast(white)).r + @as(Color, @bitcast(u32_to_u32(@bitcast(Color.{r: 1, g: 1, b: 1, a: 1})))).g
+    return u32_to_color(@bitcast(white)).r
+}
+```
+
+#### small_struct_assignment
+```hb
+Color := struct {r: u8, g: u8, b: u8, a: u8}
+white := Color.(255, 255, 255, 255)
+black := Color.(0, 0, 0, 0)
+main := fn(): int {
+    f := black
+    f = white
+    return f.a
 }
 ```


@@ -39,16 +39,6 @@ type Nid = u16;

 type Lookup = crate::ctx_map::CtxMap<Nid>;

-trait StoreId: Sized {
-    fn to_store(self) -> Option<Self>;
-}
-
-impl StoreId for Nid {
-    fn to_store(self) -> Option<Self> {
-        (self != MEM).then_some(self)
-    }
-}
-
 impl crate::ctx_map::CtxEntry for Nid {
     type Ctx = [Result<Node, (Nid, debug::Trace)>];
     type Key<'a> = (Kind, &'a [Nid], ty::Id);
@@ -457,6 +447,7 @@ impl Nodes {
             K::Stre => {
                 if self[target].inputs[2] != VOID
                     && self[target].inputs.len() == 4
+                    && self[self[target].inputs[1]].kind != Kind::Load
                     && self[self[target].inputs[3]].kind == Kind::Stre
                     && self[self[target].inputs[3]].lock_rc == 0
                     && self[self[target].inputs[3]].inputs[2] == self[target].inputs[2]
@@ -872,7 +863,7 @@ impl Nodes {
     #[allow(dead_code)]
     fn eliminate_stack_temporaries(&mut self) {
         'o: for stack in self[MEM].outputs.clone() {
-            if self[stack].kind != Kind::Stck {
+            if self.values[stack as usize].is_err() || self[stack].kind != Kind::Stck {
                 continue;
             }
             let mut full_read_into = None;
@@ -1032,7 +1023,7 @@ pub enum Kind {

 impl Kind {
     fn is_pinned(&self) -> bool {
-        self.is_cfg() || matches!(self, Self::Phi | Self::Mem | Self::Arg)
+        self.is_cfg() || matches!(self, Self::Phi | Self::Arg | Self::Mem)
     }

     fn is_cfg(&self) -> bool {
@@ -1629,7 +1620,7 @@ impl ItemCtx {
             if !matches!(self.nodes[stck].kind, Kind::Stck | Kind::Arg) {
                 debug_assert_matches!(
                     self.nodes[stck].kind,
-                    Kind::Phi | Kind::Return | Kind::Load
+                    Kind::Phi | Kind::Return | Kind::Load | Kind::Call { .. } | Kind::Stre
                 );
                 continue;
             }
@@ -2045,11 +2036,8 @@ impl<'a> Codegen<'a> {
         );
         debug_assert!(self.ci.nodes[region].kind != Kind::Stre);

-        let mut vc = Vc::from([VOID, value, region]);
         self.ci.nodes.load_loop_store(&mut self.ci.scope.store, &mut self.ci.loops);
-        if let Some(str) = self.ci.scope.store.value().to_store() {
-            vc.push(str);
-        }
+        let mut vc = Vc::from([VOID, value, region, self.ci.scope.store.value()]);
         for load in self.ci.scope.loads.drain(..) {
             if load == value {
                 self.ci.nodes.unlock(load);
@@ -2080,11 +2068,8 @@ impl<'a> Codegen<'a> {
             self.ty_display(self.ci.nodes[region].ty)
         );
         debug_assert!(self.ci.nodes[region].kind != Kind::Stre);
-        let mut vc = Vc::from([VOID, region]);
         self.ci.nodes.load_loop_store(&mut self.ci.scope.store, &mut self.ci.loops);
-        if let Some(str) = self.ci.scope.store.value().to_store() {
-            vc.push(str);
-        }
+        let vc = [VOID, region, self.ci.scope.store.value()];
         let load = self.ci.nodes.new_node(ty, Kind::Load, vc);
         self.ci.nodes.lock(load);
         self.ci.scope.loads.push(load);
@@ -2206,9 +2191,7 @@ impl<'a> Codegen<'a> {
                 debug_assert_ne!(self.ci.ctrl, VOID);

                 let mut inps = Vc::from([self.ci.ctrl, value.id]);
                 self.ci.nodes.load_loop_store(&mut self.ci.scope.store, &mut self.ci.loops);
-                if let Some(str) = self.ci.scope.store.value().to_store() {
-                    inps.push(str);
-                }
+                inps.push(self.ci.scope.store.value());

                 self.ci.ctrl = self.ci.nodes.new_node(ty::Id::VOID, Kind::Return, inps);
@@ -2570,9 +2553,7 @@ impl<'a> Codegen<'a> {
                 self.ci.nodes.unlock(n);
             }

-            if let Some(str) = self.ci.scope.store.value().to_store() {
-                inps.push(str);
-            }
+            inps.push(self.ci.scope.store.value());
             self.ci.scope.loads.retain(|&load| {
                 if inps.contains(&load) {
                     return true;
@@ -2658,9 +2639,7 @@ impl<'a> Codegen<'a> {
                 }

                 if has_ptr_arg {
-                    if let Some(str) = self.ci.scope.store.value().to_store() {
-                        inps.push(str);
-                    }
+                    inps.push(self.ci.scope.store.value());
                     self.ci.scope.loads.retain(|&load| {
                         if inps.contains(&load) {
                             return true;
@@ -4243,7 +4222,7 @@ fn push_up(nodes: &mut Nodes) {
         return;
     }

-    for i in 0..nodes[node].inputs.len() {
+    for i in 1..nodes[node].inputs.len() {
         let inp = nodes[node].inputs[i];
         if !nodes[inp].kind.is_pinned() {
             push_up_impl(inp, nodes);
@@ -4298,6 +4277,20 @@ fn push_up(nodes: &mut Nodes) {
             }
         }
     }
+
+    debug_assert_eq!(
+        nodes
+            .iter()
+            .map(|(n, _)| n)
+            .filter(|&n| !nodes.visited.get(n) && !matches!(nodes[n].kind, Kind::Arg | Kind::Mem))
+            .collect::<Vec<_>>(),
+        vec![],
+        "{:?}",
+        nodes
+            .iter()
+            .filter(|&(n, nod)| !nodes.visited.get(n) && !matches!(nod.kind, Kind::Arg | Kind::Mem))
+            .collect::<Vec<_>>()
+    );
 }

 fn push_down(nodes: &mut Nodes, node: Nid) {
@@ -4312,6 +4305,7 @@ fn push_down(nodes: &mut Nodes, node: Nid) {
 }

 fn better(nodes: &mut Nodes, is: Nid, then: Nid) -> bool {
+    debug_assert_ne!(idepth(nodes, is), idepth(nodes, then), "{is} {then}");
     loop_depth(is, nodes) < loop_depth(then, nodes)
         || idepth(nodes, is) > idepth(nodes, then)
         || nodes[then].kind == Kind::If
@@ -4321,6 +4315,12 @@ fn push_down(nodes: &mut Nodes, node: Nid) {
         return;
     }

+    for usage in nodes[node].outputs.clone() {
+        if is_forward_edge(usage, node, nodes) && nodes[node].kind == Kind::Stre {
+            push_down(nodes, usage);
+        }
+    }
+
     for usage in nodes[node].outputs.clone() {
         if is_forward_edge(usage, node, nodes) {
             push_down(nodes, usage);
@@ -4342,14 +4342,11 @@ fn push_down(nodes: &mut Nodes, node: Nid) {
     debug_assert!(nodes.dominates(nodes[node].inputs[0], min));

     let mut cursor = min;
-    loop {
+    while cursor != nodes[node].inputs[0] {
+        cursor = idom(nodes, cursor);
         if better(nodes, cursor, min) {
             min = cursor;
         }
-        if cursor == nodes[node].inputs[0] {
-            break;
-        }
-        cursor = idom(nodes, cursor);
     }

     if nodes[min].kind.ends_basic_block() {
@@ -4468,6 +4465,7 @@ mod tests {
         // Purely Testing Examples;
         returning_global_struct;
         small_struct_bitcast;
+        small_struct_assignment;
         wide_ret;
         comptime_min_reg_leak;
         different_types;
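
The `push_down` hunks above also restructure the late-scheduling walk: starting from the block the uses require, the loop now steps up the immediate-dominator chain toward the node's control input and keeps the best block seen. A minimal sketch of that walk over a toy dominator tree, keeping only the loop-depth criterion (the real `better` also compares dominator depth and special-cases `If`); `parent` and `loop_depth` are hypothetical stand-ins for the compiler's `idom` and `loop_depth` queries.

```rust
// Toy "schedule late" walk: hoist a node from the block its uses demand
// toward its earliest legal block, preferring blocks in shallower loops.
fn idom(parent: &[usize], block: usize) -> usize {
    parent[block]
}

fn better(loop_depth: &[u32], is: usize, then: usize) -> bool {
    // Simplified criterion: a block is better if it sits in a shallower loop.
    loop_depth[is] < loop_depth[then]
}

fn schedule_late(parent: &[usize], loop_depth: &[u32], late: usize, early: usize) -> usize {
    let mut min = late;
    let mut cursor = late;
    // Mirror of the restructured loop: step up the dominator chain until the
    // earliest legal block is reached, keeping the best block seen so far.
    while cursor != early {
        cursor = idom(parent, cursor);
        if better(loop_depth, cursor, min) {
            min = cursor;
        }
    }
    min
}

fn main() {
    // Dominator tree 0 -> 1 -> 2 -> 3, where blocks 2 and 3 are inside a loop.
    let parent = [0, 0, 1, 2];
    let loop_depth = [0, 0, 1, 1];
    // A value used in block 3 but legal from block 1 onward is hoisted out of
    // the loop to block 1.
    assert_eq!(schedule_late(&parent, &loop_depth, 3, 1), 1);
}
```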


@@ -287,7 +287,6 @@ impl BitSet {
         self.data.resize(new_len, 0);
     }

-    #[track_caller]
     pub fn set(&mut self, idx: Nid) -> bool {
         let idx = idx as usize;
         let data_idx = idx / Self::ELEM_SIZE;
@@ -296,4 +295,12 @@ impl BitSet {
         self.data[data_idx] |= 1 << sub_idx;
         prev == 0
     }
+
+    pub fn get(&self, idx: Nid) -> bool {
+        let idx = idx as usize;
+        let data_idx = idx / Self::ELEM_SIZE;
+        let sub_idx = idx % Self::ELEM_SIZE;
+        let prev = self.data[data_idx] & (1 << sub_idx);
+        prev != 0
+    }
 }
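
The new `get` accessor reads a bit using the same word/bit indexing that `set` already uses. A self-contained sketch of that layout for reference; this is not the repository's `BitSet`, just the same indexing scheme over `usize` words.

```rust
// Word-indexed bitset: bit `i` lives in word `i / ELEM_SIZE`
// at position `i % ELEM_SIZE`, matching the hunk above.
struct BitSet {
    data: Vec<usize>,
}

impl BitSet {
    const ELEM_SIZE: usize = usize::BITS as usize;

    fn with_capacity(bits: usize) -> Self {
        Self { data: vec![0; bits.div_ceil(Self::ELEM_SIZE)] }
    }

    /// Sets the bit and reports whether it was previously clear.
    fn set(&mut self, idx: usize) -> bool {
        let (word, bit) = (idx / Self::ELEM_SIZE, idx % Self::ELEM_SIZE);
        let prev = self.data[word] & (1 << bit);
        self.data[word] |= 1 << bit;
        prev == 0
    }

    /// Reads the bit without modifying it (the accessor this commit adds).
    fn get(&self, idx: usize) -> bool {
        let (word, bit) = (idx / Self::ELEM_SIZE, idx % Self::ELEM_SIZE);
        (self.data[word] & (1 << bit)) != 0
    }
}

fn main() {
    let mut visited = BitSet::with_capacity(128);
    assert!(!visited.get(70));
    assert!(visited.set(70));  // first set reports the bit was clear
    assert!(!visited.set(70)); // second set reports it was already set
    assert!(visited.get(70));
}
```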


@@ -1,17 +1,17 @@
 main:
 LI64 r5, 0d
 LRA r4, r0, :gb
-LD r6, r4, 0a, 8h
-CMPS r9, r6, r5
-CMPUI r9, r9, 0d
-ORI r11, r9, 0d
-LI64 r1, 6d
-ANDI r11, r11, 255d
-JNE r11, r0, :0
-CP r5, r1
+LI64 r10, 6d
+LD r7, r4, 0a, 8h
+CMPS r11, r7, r5
+CMPUI r11, r11, 0d
+ORI r12, r11, 0d
+ANDI r12, r12, 255d
+JNE r12, r0, :0
+CP r5, r10
 JMP :1
 0: LI64 r5, 1d
-1: SUB64 r1, r5, r1
+1: SUB64 r1, r5, r10
 JALA r0, r31, 0a
 code size: 131
 ret: 0


@@ -1,33 +1,29 @@
 main:
-ADDI64 r254, r254, -24d
-ST r31, r254, 8a, 16h
+ADDI64 r254, r254, -12d
+ST r31, r254, 4a, 8h
 LRA r1, r0, :white
-LD r32, r1, 0a, 4h
-ADDI64 r5, r254, 4d
-CP r2, r32
+ADDI64 r4, r254, 0d
+LD r2, r1, 0a, 4h
 JAL r31, r0, :u32_to_color
-ST r1, r254, 4a, 4h
-CP r2, r32
-JAL r31, r0, :u32_to_u32
-ADDI64 r12, r254, 0d
-LD r11, r254, 4a, 1h
 ST r1, r254, 0a, 4h
-LD r4, r254, 1a, 1h
-ADD8 r7, r4, r11
-ANDI r1, r7, 255d
-LD r31, r254, 8a, 16h
-ADDI64 r254, r254, 24d
+LD r9, r254, 0a, 1h
+ANDI r1, r9, 255d
+LD r31, r254, 4a, 8h
+ADDI64 r254, r254, 12d
 JALA r0, r31, 0a
 u32_to_color:
-ADDI64 r254, r254, -4d
-ADDI64 r3, r254, 0d
-ST r2, r254, 0a, 4h
-LD r1, r3, 0a, 4h
-ADDI64 r254, r254, 4d
+ADDI64 r254, r254, -12d
+ST r31, r254, 4a, 8h
+JAL r31, r0, :u32_to_u32
+ADDI64 r5, r254, 0d
+ST r1, r254, 0a, 4h
+LD r1, r5, 0a, 4h
+LD r31, r254, 4a, 8h
+ADDI64 r254, r254, 12d
 JALA r0, r31, 0a
 u32_to_u32:
 CP r1, r2
 JALA r0, r31, 0a
-code size: 284
-ret: 254
+code size: 263
+ret: 255
 status: Ok(())