#![feature(if_let_guard)]
#![feature(slice_take)]
use {
    cranelift_codegen::{
        CodegenError, Final, FinalizedMachReloc, MachBufferFinalized,
        ir::{InstBuilder, UserExternalName},
        isa::LookupError,
        settings::Configurable,
    },
    cranelift_frontend::FunctionBuilder,
    cranelift_module::{Module, ModuleError},
    hblang::{
        nodes::Kind,
        utils::{Ent, EntVec},
    },
    std::{fmt::Display, ops::Range},
};

mod x86_64;

/// Cranelift-based implementation of `hblang::backend::Backend` that produces an object file.
///
/// Function bodies are compiled one at a time by `emit_body` and buffered in `funcs`; the
/// reachable ones are declared, defined and written out by `assemble_reachable`.
pub struct Backend {
    /// Reusable Cranelift compilation context; cleared before each function is compiled.
    ctx: cranelift_codegen::Context,
    /// Reusable data description used to define each reachable global.
    dt_ctx: cranelift_module::DataDescription,
    /// Reusable scratch state for `cranelift_frontend::FunctionBuilder`.
    fb_ctx: cranelift_frontend::FunctionBuilderContext,
    /// Object module being built; taken (left `None`) by `assemble_reachable`.
    module: Option<cranelift_object::ObjectModule>,
    /// Control plane passed to `Context::compile`.
    ctrl_plane: cranelift_codegen::control::ControlPlane,
    /// Machine code, relocations and external names of every compiled function.
    funcs: Functions,
    /// Per-global module declarations, indexed by hblang global id.
    globals: EntVec<hblang::ty::Global, Global>,
    /// Scratch state used while assembling reachable items.
    asm: Assembler,
}

impl Backend {
    /// Creates a backend that emits an object file for `triple`.
    ///
    /// The Cranelift verifier is explicitly enabled, and the object module is named `main`.
    ///
    /// # Errors
    /// Returns [`BackendCreationError`] when the target triple has no Cranelift ISA, when the
    /// ISA rejects the shared flags, or when the object module cannot be configured.
    pub fn new(triple: target_lexicon::Triple) -> Result<Self, BackendCreationError> {
        let isa_builder = cranelift_codegen::isa::lookup(triple)?;

        let mut flag_builder = cranelift_codegen::settings::builder();
        flag_builder.set("enable_verifier", "true").unwrap();
        let shared_flags = cranelift_codegen::settings::Flags::new(flag_builder);

        let isa = isa_builder.finish(shared_flags)?;
        let object_builder = cranelift_object::ObjectBuilder::new(
            isa,
            "main",
            cranelift_module::default_libcall_names(),
        )?;

        Ok(Self {
            ctx: cranelift_codegen::Context::new(),
            dt_ctx: cranelift_module::DataDescription::new(),
            fb_ctx: cranelift_frontend::FunctionBuilderContext::default(),
            ctrl_plane: cranelift_codegen::control::ControlPlane::default(),
            module: Some(cranelift_object::ObjectModule::new(object_builder)),
            funcs: Default::default(),
            globals: Default::default(),
            asm: Default::default(),
        })
    }
}

impl hblang::backend::Backend for Backend {
    /// Declares and defines every function and global reachable from `from`, then writes the
    /// finished object file into `to`.
    ///
    /// `from` is exported under the symbol `main`. Every other item gets local linkage and a
    /// `<file path>.<identifier>` symbol name. Each item is declared and defined once, even
    /// when it is reachable along several paths or recursively.
    ///
    /// # Panics
    /// Panics if called more than once (the object module is consumed), if a reachable
    /// function was never compiled by `emit_body` (debug builds), or if any Cranelift module
    /// operation fails.
    fn assemble_reachable(
        &mut self,
        from: hblang::ty::Func,
        types: &hblang::ty::Types,
        files: &hblang::utils::EntSlice<hblang::ty::Module, hblang::parser::Ast>,
        to: &mut Vec<u8>,
    ) -> hblang::backend::AssemblySpec {
        debug_assert!(self.asm.frontier.is_empty());
        debug_assert!(self.asm.funcs.is_empty());
        debug_assert!(self.asm.globals.is_empty());

        let mut module = self.module.take().expect("backend can assemble only once");

        /// Decodes a Cranelift user external name back into the hblang item it refers to:
        /// namespace 0 is a function, namespace 1 a global. These are also the namespaces
        /// `cranelift_module` resolves to `FuncId` / `DataId` when processing relocations.
        fn clif_name_to_ty(name: UserExternalName) -> hblang::ty::Id {
            match name.namespace {
                0 => hblang::ty::Kind::Func(hblang::ty::Func::new(name.index as _)),
                1 => hblang::ty::Kind::Global(hblang::ty::Global::new(name.index as _)),
                _ => unreachable!(),
            }
            .compress()
        }

        self.globals.shadow(types.ins.globals.len());

        // Walk everything reachable from `from`, declaring each item exactly once.
        self.asm.frontier.push(from.into());
        while let Some(itm) = self.asm.frontier.pop() {
            match itm.expand() {
                hblang::ty::Kind::Func(func) => {
                    let fuc = &mut self.funcs.headers[func];
                    // Already declared: reached again via another caller or via recursion.
                    // Re-declaring would queue a second definition (and recursion would
                    // never terminate).
                    if fuc.module_id.is_some() {
                        continue;
                    }
                    self.asm.funcs.push(func);
                    self.asm.frontier.extend(
                        fuc.external_names.clone().map(|r| {
                            clif_name_to_ty(self.funcs.external_names[r as usize].clone())
                        }),
                    );
                    self.asm.name.clear();
                    if func == from {
                        self.asm.name.push_str("main");
                    } else {
                        let file = &files[types.ins.funcs[func].file];
                        self.asm.name.push_str(&file.path);
                        self.asm.name.push('.');
                        self.asm.name.push_str(file.ident_str(types.ins.funcs[func].name));
                    }
                    let linkage = if func == from {
                        cranelift_module::Linkage::Export
                    } else {
                        cranelift_module::Linkage::Local
                    };
                    build_signature(
                        module.isa().default_call_conv(),
                        types.ins.funcs[func].sig,
                        types,
                        &mut self.ctx.func.signature,
                        &mut vec![],
                    );
                    fuc.module_id = Some(
                        module
                            .declare_function(&self.asm.name, linkage, &self.ctx.func.signature)
                            .unwrap(),
                    );
                }
                hblang::ty::Kind::Global(glob) => {
                    if self.globals[glob].module_id.is_some() {
                        continue;
                    }
                    self.asm.globals.push(glob);
                    self.asm.name.clear();
                    let file = &files[types.ins.globals[glob].file];
                    self.asm.name.push_str(&file.path);
                    self.asm.name.push('.');
                    self.asm.name.push_str(file.ident_str(types.ins.globals[glob].name));
                    self.globals[glob].module_id = Some(
                        module
                            .declare_data(
                                &self.asm.name,
                                cranelift_module::Linkage::Local,
                                true,
                                false,
                            )
                            .unwrap(),
                    );
                }
                _ => unreachable!(),
            }
        }

        for func in self.asm.funcs.drain(..) {
            let fuc = &self.funcs.headers[func];
            debug_assert!(!fuc.code.is_empty());

            // The stored relocations refer to external names by their position in the
            // function's `user_named_funcs` table. Rebuild that table from scratch, in the
            // original order, with every name rewritten to point at the *referenced* item's
            // module id (not the id of the function being defined).
            self.ctx.clear();
            let names = &self.funcs.external_names
                [fuc.external_names.start as usize..fuc.external_names.end as usize];
            for (position, name) in names.iter().enumerate() {
                let index = match clif_name_to_ty(name.clone()).expand() {
                    hblang::ty::Kind::Func(callee) => self.funcs.headers[callee]
                        .module_id
                        .expect("every reachable function is declared before definition")
                        .as_u32(),
                    hblang::ty::Kind::Global(glob) => self.globals[glob]
                        .module_id
                        .expect("every reachable global is declared before definition")
                        .as_u32(),
                    _ => unreachable!(),
                };
                let reff = self
                    .ctx
                    .func
                    .params
                    .ensure_user_func_name(UserExternalName { namespace: name.namespace, index });
                // Two names collapsing into one entry would shift every later reference.
                debug_assert_eq!(reff.as_u32() as usize, position);
            }

            module
                .define_function_bytes(
                    fuc.module_id.unwrap(),
                    &self.ctx.func,
                    fuc.alignment as _,
                    &self.funcs.code[fuc.code.start as usize..fuc.code.end as usize],
                    &self.funcs.relocs[fuc.relocs.start as usize..fuc.relocs.end as usize],
                )
                .unwrap();
        }

        for global in self.asm.globals.drain(..) {
            let glob = &self.globals[global];
            self.dt_ctx.clear();
            self.dt_ctx.define(types.ins.globals[global].data.clone().into());
            module.define_data(glob.module_id.unwrap(), &self.dt_ctx).unwrap();
        }

        module.finish().object.write_stream(to).unwrap();

        hblang::backend::AssemblySpec { code_length: 0, data_length: 0, entry: 0 }
    }

    /// Disassembly is not supported by this backend.
    fn disasm<'a>(
        &'a self,
        _sluce: &[u8],
        _eca_handler: &mut dyn FnMut(&mut &[u8]),
        _types: &'a hblang::ty::Types,
        _files: &'a hblang::utils::EntSlice<hblang::ty::Module, hblang::parser::Ast>,
        _output: &mut String,
    ) -> Result<(), std::boxed::Box<dyn core::error::Error + Send + Sync + 'a>> {
        unimplemented!()
    }

    /// Lowers the node graph of `id` to Cranelift IR, compiles it, and buffers the resulting
    /// machine code, relocations and external names in `self.funcs`.
    ///
    /// # Panics
    /// Panics if called after `assemble_reachable` (the module is gone) or if Cranelift
    /// fails to compile the function.
    fn emit_body(
        &mut self,
        id: hblang::ty::Func,
        nodes: &hblang::nodes::Nodes,
        tys: &hblang::ty::Types,
        files: &hblang::utils::EntSlice<hblang::ty::Module, hblang::parser::Ast>,
    ) {
        self.ctx.clear();

        let mut lens = vec![];
        let stack_ret = build_signature(
            self.module.as_ref().unwrap().isa().default_call_conv(),
            tys.ins.funcs[id].sig,
            tys,
            &mut self.ctx.func.signature,
            &mut lens,
        );

        FuncBuilder {
            bl: FunctionBuilder::new(&mut self.ctx.func, &mut self.fb_ctx),
            nodes,
            tys,
            files,
            values: &mut vec![None; nodes.len()],
        }
        .build(tys.ins.funcs[id].sig, &lens, stack_ret);

        self.ctx.compile(self.module.as_ref().unwrap().isa(), &mut self.ctrl_plane).unwrap();
        let code = self.ctx.compiled_code().unwrap();
        self.funcs.push(id, &self.ctx.func, &code.buffer);
    }
}

/// Resets `signature` to `call_conv` and fills in its parameters and returns for `sig`.
///
/// `arg_lens` receives whatever the convention-specific builder records about the arguments;
/// `FuncBuilder::build` consumes it as "Cranelift parameters per value argument". The
/// returned flag is used as `stack_ret`: when set, the first parameter is treated as a hidden
/// return pointer.
///
/// Only the System V calling convention is implemented; any other hits `todo!()`.
fn build_signature(
    call_conv: cranelift_codegen::isa::CallConv,
    sig: hblang::ty::Sig,
    types: &hblang::ty::Types,
    signature: &mut cranelift_codegen::ir::Signature,
    arg_lens: &mut Vec<usize>,
) -> bool {
    use cranelift_codegen::isa::CallConv;

    signature.clear(call_conv);
    if !matches!(call_conv, CallConv::SystemV) {
        todo!()
    }
    x86_64::build_systemv_signature(sig, types, signature, arg_lens)
}

/// Per-function state used to lower an hblang node graph into Cranelift IR.
struct FuncBuilder<'a, 'b> {
    /// Cranelift IR builder for the function being emitted.
    bl: cranelift_frontend::FunctionBuilder<'b>,
    /// The hblang node graph of the function.
    nodes: &'a hblang::nodes::Nodes,
    /// Type table used to size and classify node types.
    tys: &'a hblang::ty::Types,
    /// Parsed source files; currently unused.
    #[expect(unused)]
    files: &'a hblang::utils::EntSlice<hblang::ty::Module, hblang::parser::Ast>,
    /// Lowering result per node id: `Ok(value)` for data nodes, `Err(block)` for control
    /// nodes, `None` for nodes not emitted yet (`emit_body` starts it all-`None`).
    values: &'b mut [Option<Result<cranelift_codegen::ir::Value, cranelift_codegen::ir::Block>>],
}

impl FuncBuilder<'_, '_> {
    pub fn build(mut self, sig: hblang::ty::Sig, arg_lens: &[usize], stack_ret: bool) {
        let entry = self.bl.create_block();
        self.bl.append_block_params_for_function_params(entry);
        self.bl.switch_to_block(entry);
        let mut arg_vals = self.bl.block_params(entry);

        if stack_ret {
            let ret_ptr = *arg_vals.take_first().unwrap();
            self.values[hblang::nodes::MEM as usize] = Some(Ok(ret_ptr));
        }

        let Self { nodes, tys, .. } = self;

        let mut parama_len = arg_lens.iter();
        let mut typs = sig.args.args();
        let mut args = nodes[hblang::nodes::VOID].outputs[hblang::nodes::ARG_START..].iter();
        while let Some(aty) = typs.next(tys) {
            let hblang::ty::Arg::Value(ty) = aty else { continue };
            let loc = arg_vals.take(..*parama_len.next().unwrap()).unwrap();
            let &arg = args.next().unwrap();
            if ty.is_aggregate(tys) {
                todo!()
            } else {
                debug_assert_eq!(loc.len(), 0);
                self.values[arg as usize] = Some(Ok(loc[0]));
            }
        }

        self.values[hblang::nodes::ENTRY as usize] = Some(Err(entry));

        self.emit_node(hblang::nodes::VOID, hblang::nodes::VOID);

        self.bl.finalize();
    }

    fn value_of(&self, nid: hblang::nodes::Nid) -> cranelift_codegen::ir::Value {
        self.values[nid as usize].unwrap().unwrap()
    }

    fn block_of(&self, nid: hblang::nodes::Nid) -> cranelift_codegen::ir::Block {
        self.values[nid as usize].unwrap().unwrap_err()
    }

    fn close_block(&mut self, nid: hblang::nodes::Nid) {
        if matches!(self.nodes[nid].kind, Kind::Loop | Kind::Region) {
            return;
        }
        self.bl.seal_block(self.block_of(nid));
    }

    fn emit_node(&mut self, nid: hblang::nodes::Nid, block: hblang::nodes::Nid) {
        use hblang::nodes::*;

        let mut args = vec![];
        if matches!(self.nodes[nid].kind, Kind::Region | Kind::Loop) {
            let side = 1 + self.values[nid as usize].is_some() as usize;
            for &o in self.nodes[nid].outputs.iter() {
                if self.nodes[o].is_data_phi() {
                    args.push(self.value_of(self.nodes[0].inputs[side]));
                }
            }
            match (self.nodes[nid].kind, self.values[nid as usize]) {
                (Kind::Loop, Some(blck)) => {
                    self.bl.ins().jump(blck.unwrap_err(), &args);
                    self.bl.seal_block(blck.unwrap_err());
                    return;
                }
                (Kind::Region, None) => {
                    let next = self.bl.create_block();
                    for &o in self.nodes[nid].outputs.iter() {
                        if self.nodes[o].is_data_phi() {
                            self.values[o as usize] = Some(Ok(self.bl.append_block_param(
                                next,
                                ty_to_clif_ty(self.nodes[o].ty, self.tys),
                            )));
                        }
                    }
                    self.bl.ins().jump(next, &args);
                    self.bl.seal_block(next);
                    self.values[nid as usize] = Some(Err(next));
                    return;
                }
                _ => {}
            }
        }

        let node = &self.nodes[nid];
        self.values[nid as usize] = Some(match node.kind {
            Kind::Start => {
                debug_assert_eq!(self.nodes[node.outputs[0]].kind, Kind::Entry);
                self.emit_node(node.outputs[0], block);
                return;
            }
            Kind::If => {
                let &[_, cnd] = node.inputs.as_slice() else { unreachable!() };
                let &[then, else_] = node.outputs.as_slice() else { unreachable!() };

                let then_bl = self.bl.create_block();
                let else_bl = self.bl.create_block();
                let c = self.value_of(cnd);
                self.bl.ins().brif(c, then_bl, &[], else_bl, &[]);
                self.values[then as usize] = Some(Err(then_bl));
                self.values[else_ as usize] = Some(Err(else_bl));

                self.close_block(block);
                self.bl.switch_to_block(then_bl);
                self.emit_node(then, then);
                self.bl.switch_to_block(else_bl);
                self.emit_node(else_, else_);
                Err(self.block_of(block))
            }
            Kind::Region | Kind::Loop => {
                if node.kind == Kind::Loop {
                    let next = self.bl.create_block();
                    for &o in self.nodes[nid].outputs.iter() {
                        if self.nodes[o].is_data_phi() {
                            self.values[o as usize] = Some(Ok(self.bl.append_block_param(
                                next,
                                ty_to_clif_ty(self.nodes[o].ty, self.tys),
                            )));
                        }
                    }
                    self.values[nid as usize] = Some(Err(next));
                }
                self.bl.ins().jump(self.values[nid as usize].unwrap().unwrap_err(), &args);
                self.close_block(block);
                self.bl.switch_to_block(self.values[nid as usize].unwrap().unwrap_err());
                for &o in node.outputs.iter().rev() {
                    self.emit_node(o, nid);
                }
                Err(self.block_of(block))
            }
            Kind::Return { .. } | Kind::Die => {
                let ret = self.value_of(node.inputs[1]);
                self.bl.ins().return_(&[ret]);
                self.close_block(block);
                self.emit_node(node.outputs[0], block);
                Err(self.block_of(block))
            }
            Kind::Entry => {
                for &o in node.outputs.iter().rev() {
                    self.emit_node(o, nid);
                }
                return;
            }
            Kind::Then | Kind::Else => {
                for &o in node.outputs.iter().rev() {
                    self.emit_node(o, block);
                }
                Err(self.block_of(block))
            }
            Kind::Call { func: _, unreachable, .. } => {
                if unreachable {
                    todo!()
                } else {
                    todo!();
                    //for &o in node.outputs.iter().rev() {
                    //    if self.nodes[o].inputs[0] == nid
                    //        || (matches!(self.nodes[o].kind, Kind::Loop | Kind::Region)
                    //            && self.nodes[o].inputs[1] == nid)
                    //    {
                    //        self.emit_node(o, block);
                    //    }
                    //}
                }
            }
            Kind::CInt { value } if self.nodes[nid].ty.is_integer() => Ok(self.bl.ins().iconst(
                cranelift_codegen::ir::Type::int(self.tys.size_of(self.nodes[nid].ty) as u16 * 8)
                    .unwrap(),
                value,
            )),
            Kind::CInt { value } => Ok(match self.tys.size_of(self.nodes[nid].ty) {
                4 => self.bl.ins().f32const(f64::from_bits(value as _) as f32),
                8 => self.bl.ins().f64const(f64::from_bits(value as _)),
                _ => unimplemented!(),
            }),
            Kind::BinOp { .. }
            | Kind::UnOp { .. }
            | Kind::Global { .. }
            | Kind::Load { .. }
            | Kind::Stre
            | Kind::RetVal
            | Kind::Stck => todo!(),
            Kind::End | Kind::Phi | Kind::Arg | Kind::Mem | Kind::Loops | Kind::Join => return,
            Kind::Assert { .. } => unreachable!(),
        });
    }
}

/// Maps an hblang type to the Cranelift value type used to hold it.
///
/// Only integer types are supported: they map to the Cranelift integer type with the same
/// byte size. Any other type hits `unimplemented!()`.
fn ty_to_clif_ty(ty: hblang::ty::Id, tys: &hblang::ty::Types) -> cranelift_codegen::ir::Type {
    if !ty.is_integer() {
        unimplemented!()
    }
    let bits = tys.size_of(ty) as u16 * 8;
    cranelift_codegen::ir::Type::int(bits).unwrap()
}

/// Backend-side bookkeeping for one hblang global.
#[derive(Default)]
struct Global {
    /// Data id in the object module; set when the global is declared during assembly.
    module_id: Option<cranelift_module::DataId>,
}

/// Where one compiled function's artifacts live inside the shared buffers of `Functions`.
#[derive(Default)]
struct FuncHeaders {
    /// Function id in the object module; set when the function is declared during assembly.
    module_id: Option<cranelift_module::FuncId>,
    /// Code alignment reported by the compiled machine buffer.
    alignment: u32,
    /// Byte range of this function's machine code in `Functions::code`.
    code: Range<u32>,
    /// Range of this function's relocations in `Functions::relocs`.
    relocs: Range<u32>,
    /// Range of this function's user external names in `Functions::external_names`.
    external_names: Range<u32>,
}

/// Compiled artifacts of every function passed to `emit_body`, stored back to back.
#[derive(Default)]
struct Functions {
    /// Per-function ranges into the buffers below, indexed by hblang function id.
    headers: EntVec<hblang::ty::Func, FuncHeaders>,
    /// Concatenated machine code of all functions.
    code: Vec<u8>,
    /// Concatenated relocations of all functions.
    relocs: Vec<FinalizedMachReloc>,
    /// Concatenated `user_named_funcs` tables of all functions, each in its original order.
    external_names: Vec<UserExternalName>,
}

impl Functions {
    /// Appends the compiled artifacts of function `id` to the shared buffers and records
    /// their location in `self.headers[id]`.
    ///
    /// The header's `module_id` starts out as `None`.
    fn push(
        &mut self,
        id: hblang::ty::Func,
        func: &cranelift_codegen::ir::Function,
        code: &MachBufferFinalized<Final>,
    ) {
        let bytes = code.data();
        let relocs = code.relocs();
        let names = func.params.user_named_funcs();

        let code_start = self.code.len() as u32;
        let relocs_start = self.relocs.len() as u32;
        let names_start = self.external_names.len() as u32;

        self.headers.shadow(id.index() + 1);
        self.headers[id] = FuncHeaders {
            module_id: None,
            alignment: code.alignment,
            code: code_start..code_start + bytes.len() as u32,
            relocs: relocs_start..relocs_start + relocs.len() as u32,
            external_names: names_start..names_start + names.len() as u32,
        };

        self.code.extend_from_slice(bytes);
        self.relocs.extend_from_slice(relocs);
        self.external_names.extend(names.values().cloned());
    }
}

/// Scratch state used by `assemble_reachable`.
#[derive(Default)]
struct Assembler {
    /// Buffer in which the symbol name of the item being declared is built.
    name: String,
    /// Worklist of functions and globals still to be visited.
    frontier: Vec<hblang::ty::Id>,
    /// Globals declared during the walk, to be defined afterwards.
    globals: Vec<hblang::ty::Global>,
    /// Functions declared during the walk, to be defined afterwards.
    funcs: Vec<hblang::ty::Func>,
}

/// Reasons `Backend::new` can fail.
#[derive(Debug)]
pub enum BackendCreationError {
    /// Cranelift has no ISA for the requested target triple.
    UnsupportedTriplet(LookupError),
    /// The ISA rejected the shared settings flags.
    InvalidFlags(CodegenError),
    /// The object module builder could not be created.
    UnsupportedModuleConfig(ModuleError),
}

impl Display for BackendCreationError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            BackendCreationError::UnsupportedTriplet(err) => {
                write!(f, "Unsupported triplet: {}", err)
            }
            BackendCreationError::InvalidFlags(err) => {
                write!(f, "Invalid flags: {}", err)
            }
            BackendCreationError::UnsupportedModuleConfig(err) => {
                write!(f, "Unsupported module configuration: {}", err)
            }
        }
    }
}
impl core::error::Error for BackendCreationError {}

impl From<LookupError> for BackendCreationError {
    fn from(value: LookupError) -> Self {
        Self::UnsupportedTriplet(value)
    }
}

impl From<CodegenError> for BackendCreationError {
    fn from(value: CodegenError) -> Self {
        Self::InvalidFlags(value)
    }
}

impl From<ModuleError> for BackendCreationError {
    fn from(value: ModuleError) -> Self {
        Self::UnsupportedModuleConfig(value)
    }
}