diff --git a/lang/src/son.rs b/lang/src/son.rs index f6ac48c..e34b951 100644 --- a/lang/src/son.rs +++ b/lang/src/son.rs @@ -3800,7 +3800,7 @@ impl<'a> Function<'a> { block.push(self.rg(ph)); } self.blocks[self.nodes[nid].ralloc_backref as usize].params = block; - self.reschedule_block(&mut node.outputs); + self.reschedule_block(nid, &mut node.outputs); for o in node.outputs.into_iter().rev() { self.emit_node(o, nid); } @@ -3874,7 +3874,7 @@ impl<'a> Function<'a> { )]); } - self.reschedule_block(&mut node.outputs); + self.reschedule_block(nid, &mut node.outputs); for o in node.outputs.into_iter().rev() { self.emit_node(o, nid); } @@ -3882,7 +3882,7 @@ impl<'a> Function<'a> { Kind::Then | Kind::Else => { self.nodes[nid].ralloc_backref = self.add_block(nid); self.bridge(prev, nid); - self.reschedule_block(&mut node.outputs); + self.reschedule_block(nid, &mut node.outputs); for o in node.outputs.into_iter().rev() { self.emit_node(o, nid); } @@ -3983,7 +3983,7 @@ impl<'a> Function<'a> { self.add_instr(nid, ops); - self.reschedule_block(&mut node.outputs); + self.reschedule_block(nid, &mut node.outputs); for o in node.outputs.into_iter().rev() { if self.nodes[o].inputs[0] == nid || (matches!(self.nodes[o].kind, Kind::Loop | Kind::Region) @@ -4073,7 +4073,8 @@ impl<'a> Function<'a> { .push(regalloc2::Block::new(self.nodes[pred].ralloc_backref as usize)); } - fn reschedule_block(&mut self, outputs: &mut Vc) { + fn reschedule_block(&mut self, from: Nid, outputs: &mut Vc) { + let from = Some(&from); let mut buf = Vec::with_capacity(outputs.len()); let mut seen = BitSet::default(); seen.clear(self.nodes.values.len()); @@ -4089,11 +4090,11 @@ impl<'a> Function<'a> { buf.push(o); while let Some(&n) = buf.get(cursor) { for &i in &self.nodes[n].inputs[1..] { - if self.nodes[n].inputs.first() == self.nodes[i].inputs.first() - && self.nodes[i].outputs.iter().all(|&o| { - self.nodes[o].inputs.first() != self.nodes[i].inputs.first() - || seen.get(o) - }) + if from == self.nodes[i].inputs.first() + && self.nodes[i] + .outputs + .iter() + .all(|&o| self.nodes[o].inputs.first() != from || seen.get(o)) && seen.set(i) { buf.push(i); @@ -4111,11 +4112,11 @@ impl<'a> Function<'a> { buf.push(o); while let Some(&n) = buf.get(cursor) { for &i in &self.nodes[n].inputs[1..] { - if self.nodes[n].inputs.first() == self.nodes[i].inputs.first() - && self.nodes[i].outputs.iter().all(|&o| { - self.nodes[o].inputs.first() != self.nodes[i].inputs.first() - || seen.get(o) - }) + if from == self.nodes[i].inputs.first() + && self.nodes[i] + .outputs + .iter() + .all(|&o| self.nodes[o].inputs.first() != from || seen.get(o)) && seen.set(i) { buf.push(i); @@ -4125,7 +4126,12 @@ impl<'a> Function<'a> { } } - debug_assert!(outputs.len() == buf.len() || outputs.len() == buf.len() + 1,); + debug_assert!( + outputs.len() == buf.len() || outputs.len() == buf.len() + 1, + "{:?} {:?}", + outputs, + buf + ); if buf.len() + 1 == outputs.len() { outputs.remove(outputs.len() - 1);