use {
    crate::memory::{MemoryManager, PhysicalAddress, VirtualAddress},
    alloc::boxed::Box,
    spin::{Mutex, Once},
};

use super::PAGE_SIZE;

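/// The page sizes supported by the SV39 translation scheme.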
pub enum PageSize {
    Size4KiB,
    Size2MiB,
    Size1GiB,
    // FIXME: SV48 support
    // Size512GiB,
    // FIXME: SV57 support
    // Size256TiB,
}

impl PageSize {
    fn level(&self) -> usize {
        match self {
            PageSize::Size4KiB => 0,
            PageSize::Size2MiB => 1,
            PageSize::Size1GiB => 2,
            // FIXME: SV48 and SV57 support
        }
    }
}

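/// A single level of an SV39 page table: 512 entries, exactly one 4 KiB page on RV64.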
pub struct PageTable {
    entries: [PageEntry; 512],
}

impl PageTable {
    /// Walk the page table to convert a virtual address to a physical address.
    /// If a page fault would occur, this returns None. Otherwise, it returns the physical address.
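    ///
    /// A sketch of a lookup; `root` is assumed to be a `&PageTable` referencing
    /// the root table, and the address is made up for illustration:
    ///
    /// ```ignore
    /// if let Some(paddr) = root.virt_to_phys(0x8020_1000usize.into()) {
    ///     // `paddr` is the physical address backing this virtual address.
    /// }
    /// ```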
    pub fn virt_to_phys(&self, vaddr: VirtualAddress) -> Option<PhysicalAddress> {
        let vpn = vaddr.vpns();

        let mut v = &self.entries[vpn[2]];
        for i in (0..=2).rev() {
            if v.is_invalid() {
                // This is an invalid entry, page fault.
                break;
            } else if v.is_leaf() {
                // In RISC-V, a leaf can be at any level.

                // The offset mask keeps the page-offset bits and clears the PPN
                // bits translated at this level: 12 offset bits plus 9 more for
                // each level below the leaf, hence 12 + i * 9.
                let off_mask = (1 << (12 + i * 9)) - 1;
                let vaddr_pgoff = vaddr.as_addr() & off_mask;
                let addr = ((v.entry() << 2) as usize) & !off_mask;
                return Some((addr | vaddr_pgoff).into());
            }
            // Set v to the next entry which is pointed to by this entry.
            // However, the address was shifted right by 2 places when stored in the page table
            // entry, so we shift it left to get it back into place.
            let entry = v.addr().as_ptr::<PageEntry>();
            // We index with vpn[i - 1] here; a valid non-leaf entry at level 0
            // is reserved by the spec, so a well-formed table hits the leaf or
            // invalid arm above before i reaches 0 and i - 1 would underflow.
            v = unsafe { entry.add(vpn[i - 1]).as_ref().unwrap() };
        }

        // If we get here, we've exhausted all valid tables and haven't
        // found a leaf.
        None
    }

    /// Maps a virtual address to a physical address.
    /// `flags` should contain only the following:
    ///   Read, Write, Execute, User, and/or Global.
    /// `flags` MUST include at least one of:
    ///   Read, Write, Execute.
    /// The Valid bit is added automatically.
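    ///
    /// A sketch of mapping one 4 KiB page; `root` is assumed to be a
    /// `&mut PageTable` for the root table and the addresses are made up:
    ///
    /// ```ignore
    /// root.map(
    ///     0x4000_0000usize.into(), // virtual address
    ///     0x8040_0000usize.into(), // physical address
    ///     PageEntryFlags::UserReadWrite,
    ///     PageSize::Size4KiB,
    /// );
    /// ```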
    pub fn map(
        &mut self,
        vaddr: VirtualAddress,
        paddr: PhysicalAddress,
        flags: PageEntryFlags,
        page_size: PageSize,
    ) {
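        // A leaf mapping must have at least one of Read, Write, or Execute set (bits 1..=3).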
        assert!(flags as usize & 0xE != 0);

        let vpn = vaddr.vpns();
        let ppn = paddr.ppns();
        let level = page_size.level();

        let mut v = &mut self.entries[vpn[2]];

        // Traverse the page table down to the level requested by page_size,
        // creating intermediate tables as needed. The root table is expected to
        // exist already; anything below it is allocated on demand.
        for i in (level..2).rev() {
            if v.is_invalid() {
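                // This level's table doesn't exist yet: allocate a zeroed page
                // for it and point this entry at it (the address is stored
                // shifted right by 2, placing the PPN field at bit 10).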
                let mut mm = MEMORY_MANAGER.get().unwrap().lock();
                let page = mm.zallocate_pages(1).unwrap().as_addr();
                v.set_entry((page as usize >> 2) | PageEntryFlags::Valid as usize);
            }

            let entry = v.addr().as_mut_ptr::<PageEntry>();
            v = unsafe { entry.add(vpn[i]).as_mut().unwrap() };
        }

        // When we get here, v points to the entry at the requested level
        // (VPN[level]). The entry layout is Figure 4.18 in the RISC-V Privileged
        // Specification.
        let entry = (ppn[2] << 28) as usize // PPN[2] = [53:28]
            | (ppn[1] << 19) as usize // PPN[1] = [27:19]
            | (ppn[0] << 10) as usize // PPN[0] = [18:10]
            | flags as usize // Specified bits, such as User, Read, Write, etc.
            | PageEntryFlags::Valid as usize;
        v.set_entry(entry);
    }

    /// Identity maps a page of memory
    pub fn identity_map(
        &mut self,
        addr: PhysicalAddress,
        flags: PageEntryFlags,
        page_size: PageSize,
    ) {
        self.map(addr.as_addr().into(), addr, flags, page_size);
    }

    /// Identity maps a range of contiguous memory
    /// This assumes that start <= end
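    ///
    /// A sketch of identity mapping a 4 KiB MMIO region; `root` and the base
    /// address are assumptions for illustration only:
    ///
    /// ```ignore
    /// root.identity_map_range(
    ///     0x1000_0000usize.into(),
    ///     0x1000_0fffusize.into(),
    ///     PageEntryFlags::ReadWrite,
    /// );
    /// ```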
    pub fn identity_map_range(
        &mut self,
        start: PhysicalAddress,
        end: PhysicalAddress,
        flags: PageEntryFlags,
    ) {
        log::debug!("start: {start}, end: {end}");
        let mut mem_addr = start.as_addr() & !(PAGE_SIZE - 1);
        let num_pages = (align_val(end.as_addr(), 12) - mem_addr - 1) / PAGE_SIZE + 1;

        for _ in 0..num_pages {
            // FIXME: we can merge these page entries if possible into Size2MiB or larger entries
            self.identity_map(mem_addr.into(), flags, PageSize::Size4KiB);
            mem_addr += 1 << 12;
        }
    }

    /// Unmaps a page of memory at vaddr
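    ///
    /// A sketch, assuming `root` is a `&mut PageTable` and the address was
    /// previously mapped with `map`:
    ///
    /// ```ignore
    /// root.unmap(0x4000_0000usize.into());
    /// ```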
    pub fn unmap(&mut self, vaddr: VirtualAddress) {
        let vpn = vaddr.vpns();

        // Traverse the page table until we reach the leaf entry for vaddr.
        let mut v = &mut self.entries[vpn[2]];
        for i in (0..=2).rev() {
            if v.is_invalid() {
                // This is an invalid entry, the page is already unmapped
                return;
            } else if v.is_leaf() {
                // A leaf can be at any level; clearing the entry unmaps the page
                v.set_entry(0);
                return;
            }

            // Descend into the next-level table. As in virt_to_phys, a valid
            // non-leaf entry at level 0 is reserved, so vpn[i - 1] is never
            // evaluated with i == 0 for a well-formed table.
            let entry = v.addr().as_mut_ptr::<PageEntry>();
            v = unsafe { entry.add(vpn[i - 1]).as_mut().unwrap() };
        }
    }

    /// Unmaps a range of contiguous memory
    /// This assumes that start <= end
    pub fn unmap_range(&mut self, start: VirtualAddress, end: VirtualAddress) {
        let mut mem_addr = start.as_addr() & !(PAGE_SIZE - 1);
        let num_pages = (align_val(end.as_addr(), 12) - mem_addr) / PAGE_SIZE;

        for _ in 0..num_pages {
            self.unmap(mem_addr.into());
            mem_addr += 1 << 12;
        }
    }

    /// Frees all memory associated with a table.
    /// NOTE: This does NOT free the table itself; that must be freed manually.
    fn destroy(&mut self) {
        for entry in &mut self.entries {
            entry.destroy()
        }
    }
}

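/// Flag bits of a page table entry, plus a few convenient combinations.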
#[repr(usize)]
#[derive(Clone, Copy, Debug)]
pub enum PageEntryFlags {
    None = 0,
    Valid = 1,
    Read = 1 << 1,
    Write = 1 << 2,
    Execute = 1 << 3,
    User = 1 << 4,
    Global = 1 << 5,
    Access = 1 << 6,
    Dirty = 1 << 7,

    // for convenience
    ReadWrite = Self::Read as usize | Self::Write as usize,
    ReadExecute = Self::Read as usize | Self::Execute as usize,
    ReadWriteExecute = Self::Read as usize | Self::Write as usize | Self::Execute as usize,
    UserReadWrite = Self::User as usize | Self::ReadWrite as usize,
    UserReadExecute = Self::User as usize | Self::ReadExecute as usize,
    UserReadWriteExecute = Self::User as usize | Self::ReadWriteExecute as usize,
}

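/// A raw SV39 page table entry.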
struct PageEntry(usize);

impl PageEntry {
    fn is_valid(&self) -> bool {
        self.0 & PageEntryFlags::Valid as usize != 0
    }

    fn is_invalid(&self) -> bool {
        !self.is_valid()
    }

    fn is_leaf(&self) -> bool {
        self.0 & PageEntryFlags::ReadWriteExecute as usize != 0
    }

    fn is_branch(&self) -> bool {
        !self.is_leaf()
    }

    fn entry(&self) -> usize {
        self.0
    }

    fn set_entry(&mut self, entry: usize) {
        self.0 = entry;
    }

    fn clear_flag(&mut self, flag: PageEntryFlags) {
        self.0 &= !(flag as usize);
    }

    fn set_flag(&mut self, flag: PageEntryFlags) {
        self.0 |= flag as usize;
    }

    fn addr(&self) -> PhysicalAddress {
        ((self.entry() & !0x3FF) << 2).into()
    }

    fn destroy(&mut self) {
        if self.is_valid() && self.is_branch() {
            // This is a valid entry so drill down and free
            let memaddr = self.addr();
            let table = memaddr.as_mut_ptr::<PageTable>();
            unsafe {
                (*table).destroy();
                let mut mm = MEMORY_MANAGER.get().unwrap().lock();
                mm.deallocate_pages(memaddr.into(), 0);
            }
        }
    }
}

// FIXME: PageTable should be integrated into MemoryManager *somehow*
pub static MEMORY_MANAGER: Once<Mutex<MemoryManager>> = Once::new();
pub static PAGE_TABLE: Once<Mutex<PhysicalAddress>> = Once::new();

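/// Initializes the global `MEMORY_MANAGER` from the given physical range and
/// stores the address of a freshly allocated root page table in `PAGE_TABLE`.
///
/// A sketch of early-boot usage; the start address and page count are
/// placeholders, not values this module provides:
///
/// ```ignore
/// init(0x8030_0000usize.into(), 64);
/// ```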
pub fn init(start_addr: PhysicalAddress, page_count: usize) {
    let mut memory_manager = MemoryManager::new();

    unsafe {
        memory_manager.add_range(start_addr, page_count);
        PAGE_TABLE.call_once(|| Mutex::new(memory_manager.zallocate_pages(0).unwrap()));
    }

    MEMORY_MANAGER.call_once(|| Mutex::new(memory_manager));
}

/// Aligns `val` up to the nearest multiple of `2^order` (always rounds up).
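///
/// A quick sanity check (marked ignore, since the function is private):
///
/// ```ignore
/// assert_eq!(align_val(0x1234, 12), 0x2000);
/// assert_eq!(align_val(0x3000, 12), 0x3000);
/// ```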
fn align_val(val: usize, order: usize) -> usize {
    let o = (1 << order) - 1;
    (val + o) & !o
}