From 4bdc602ef98fe7f6c5c01cc24761d9c08294ac6e Mon Sep 17 00:00:00 2001
From: Graham Kelly <gkgoat6700@gmail.com>
Date: Sun, 23 Jun 2024 13:31:20 -0400
Subject: [PATCH] better fusing

---
 src/passes/mem_fusing.rs | 62 +++++++++++++++++++++++++++++++++-------
 1 file changed, 51 insertions(+), 11 deletions(-)

diff --git a/src/passes/mem_fusing.rs b/src/passes/mem_fusing.rs
index b54ec01..3955508 100644
--- a/src/passes/mem_fusing.rs
+++ b/src/passes/mem_fusing.rs
@@ -1,4 +1,4 @@
-use std::{collections::BTreeMap, convert::Infallible, iter::empty};
+use std::{collections::BTreeMap, convert::Infallible, iter::{empty, once}};
 
 use anyhow::Context;
 // use libc::name_t;
@@ -55,7 +55,7 @@ impl Fuse {
             target: d,
         });
     }
-    pub fn finalize(self, m: &mut Module) -> Memory{
+    pub fn finalize(self, m: &mut Module) -> Memory {
         let mem = m.memories[self.target].clone();
         m.memories = EntityVec::default();
         let new = m.memories.push(mem);
@@ -110,7 +110,7 @@ impl Fuse {
         // }
         return new;
     }
-    pub fn process(&self, f: &mut FunctionBody) {
+    pub fn process(&self, m: &mut Module, f: &mut FunctionBody) {
         let vz = f.arg_pool.from_iter(empty());
         let tz = f.type_pool.from_iter(empty());
         let ti = f.type_pool.from_iter(vec![Type::I32].into_iter());
@@ -123,6 +123,42 @@ impl Fuse {
                 // let vi = v;
                 if let ValueDef::Operator(a, b, c) = &mut w {
                     let mut bp = f.arg_pool[*b].to_vec();
+                    fn g(
+                        a: impl for<'a> FnMut(&mut Module,&mut FunctionBody, Memory, &'a mut crate::Value),
+                    ) -> impl for<'a> FnMut(&mut Module,&mut FunctionBody, Memory, &'a mut crate::Value)
+                    {
+                        return a;
+                    }
+                    let mut p = g(|m: &mut Module,f, mem, v| {
+                        match (m.memories[mem].memory64, m.memories[self.target].memory64) {
+                            (true, true) => {}
+                            (true, false) => {
+                                let ti = f.type_pool.from_iter(once(Type::I32));
+                                let w = f.arg_pool.from_iter(vec![*v].into_iter());
+                                let x = f.add_value(ValueDef::Operator(
+                                    Operator::I32WrapI64,
+                                    w,
+                                    ti,
+                                ));
+                                f.append_to_block(k, x);
+                                // crate::append_before(f, x, vi, k);
+                                *v = x;
+                            }
+                            (false, true) => {
+                                let ti = f.type_pool.from_iter(once(Type::I64));
+                                let w = f.arg_pool.from_iter(vec![*v].into_iter());
+                                let x = f.add_value(ValueDef::Operator(
+                                    Operator::I64ExtendI32U,
+                                    w,
+                                    ti,
+                                ));
+                                f.append_to_block(k, x);
+                                // crate::append_before(f, x, vi, k);
+                                *v = x;
+                            },
+                            (false, false) => {}
+                        }
+                    });
                     match a.clone() {
                         Operator::MemorySize { mem } => {
                             if mem != self.target {
@@ -139,6 +175,7 @@ impl Fuse {
                                     function_index: self.size,
                                 };
                                 bp.push(ia);
+                                p(m,f, mem, &mut bp[0]);
                             }
                         }
                         Operator::MemoryGrow { mem } => {
@@ -156,20 +193,22 @@ impl Fuse {
                                     function_index: self.grow,
                                 };
                                 bp.push(ia);
+                                p(m, f,mem, &mut bp[0]);
                             }
                         }
-                        _ => crate::op_traits::rewrite_mem(a, &mut bp, |m, v| {
-                            if *m != self.target{
+                        _ => crate::op_traits::rewrite_mem(a, &mut bp, |mem, v| {
+                            if *mem != self.target {
                                 let ia = f.add_value(ValueDef::Operator(
                                     Operator::I32Const {
-                                        value: m.index() as u32,
+                                        value: mem.index() as u32,
                                     },
                                     vz,
                                     ti,
                                 ));
-                                f.append_to_block(k, ia); 
+                                f.append_to_block(k, ia);
                                 // append_before(f, ia, vi, k);
                                 if let Some(v) = v {
+                                    p(m,f, *mem, &mut *v);
                                     let w = f.arg_pool.from_iter(vec![*v, ia].into_iter());
                                     let x = f.add_value(ValueDef::Operator(
                                         Operator::Call {
@@ -182,10 +221,11 @@ impl Fuse {
                                     // crate::append_before(f, x, vi, k);
                                     *v = x;
                                 }
-                                *m = self.target;
+                                *mem = self.target;
                             }
-                            Ok::<(),Infallible>(())
-                        }).unwrap(),
+                            Ok::<(), Infallible>(())
+                        })
+                        .unwrap(),
                     }
                     *b = *ka
                         .entry(bp.clone())
@@ -201,7 +241,7 @@ pub fn fuse(m: &mut Module) -> anyhow::Result<()> {
     let f = Fuse::new(m).context("in getting the fuse funcs")?;
     crate::passes::unmem::metafuse_all(m, &mut crate::passes::unmem::All {});
     // crate::passes::splice::splice_module(m)?;
-    m.per_func_body(|b| f.process(b));
+    m.take_per_func_body(|m, b| f.process(m, b));
     f.finalize(m);
     return Ok(());
 }