/// A struct containing info about a PCI device.
///
/// Produced by `check_device` for function 0 of a bus/device slot.
#[derive(Copy, Clone, Debug)]
pub struct PciDeviceInfo {
    /// Header-type byte (configuration offset 0x0E), as read by `get_header_type`.
    pub header_type: u8,
    /// Device (slot) number on `bus`.
    pub device: u8,
    /// Bus number the device was found on.
    pub bus: u8,
    /// Vendor plus vendor-specific device ID.
    pub device_id: DeviceID,
    /// Class and subclass of the device.
    pub full_class: PciFullClass,
    /// Revision ID (low byte of configuration register 0x08).
    pub rev_id: u8,
}

/// Enumerate PCI devices and record them in the device tree.
///
/// Probes every bus (`0..=255`) and device slot (`0..32`) with [`check_device`], which only
/// looks at function 0 of each slot. Every device found becomes an XML element named after the
/// device, with a `"PCI Info"` child carrying its id, device, vendor, bus and class attributes.
/// The element is appended to the device-tree category chosen from its `(vendor, device id)`
/// pair; unrecognised devices go under `"Unidentified PCI"`, which is reset to an empty list
/// before scanning. Categories that do not exist yet are created on demand.
pub fn init(device_tree: &mut DeviceTree) {
    device_tree
        .devices
        .insert("Unidentified PCI", alloc::vec![]);

    use Vendor::*;
    for bus in 0..=255 {
        for device in 0..32 {
            let Some(device_info) = check_device(bus, device) else {
                continue;
            };
            let vendor = device_info.device_id.vendor;
            let id = device_info.device_id.id;

            // Known (vendor, device id) pairs, with device IDs written in hex.
            let (dev_type, dev_name) = match (vendor, id) {
                (VMWareInc, 0x0405) => ("GPUs", "SVGAII PCI GPU"),
                (Qemu, 0x1111) => ("GPUs", "QEMU VGA"),
                (VirtIO, 0x1050) => ("GPUs", "VirtIO PCI GPU"),
                (CirrusLogic, 0x00B8) => ("GPUs", "Cirrus SVGA"), // GD 5446?
                _ => ("Unidentified PCI", "UNKNOWN DEVICE"),
            };
            // TODO: fall back to `device_info.full_class` (e.g. `Display_*` => "GPUs") for
            // devices without a specific (vendor, id) match above.

            let mut dev = xml::XMLElement::new(dev_name);
            let mut pci_info = xml::XMLElement::new("PCI Info");
            pci_info.set_attribute("id", id);
            pci_info.set_attribute("device", device_info.device);
            pci_info.set_attribute("vendor", vendor);
            pci_info.set_attribute("bus", bus);
            pci_info.set_attribute("class", device_info.full_class);
            dev.set_child(pci_info);

            // Use the entry API so a category the device tree did not pre-register is created,
            // instead of the device being silently dropped.
            device_tree.devices.entry(dev_type).or_default().push(dev);
        }
    }
}

pub fn check_device(bus: u8, device: u8) -> Option<PciDeviceInfo> {
    assert!(device < 32);
    let (device_id, vendor_id) = get_ids(bus, device, 0);
    if vendor_id == 0xFFFF {
        // Device doesn't exist
        return None;
    }

    let (reg2, addr) = unsafe { pci_config_read_2(bus, device, 0, 0x8) };
    log::info!("pci device-({}) addr {} is {}", device, addr, reg2);
    let class = ((reg2 >> 16) & 0x0000_FFFF) as u16;
    let pci_class = PciFullClass::from_u16(class);
    let header_type = get_header_type(bus, device, 0);

    Some(PciDeviceInfo {
        header_type,
        device,
        bus,
        device_id: DeviceID {
            vendor: vendor_id.into(),
            id:     device_id,
        },
        full_class: pci_class,
        rev_id: (reg2 & 0x0000_00FF) as u8,
    })
}

/// Identifies a PCI device model: its vendor plus the vendor-assigned device ID.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct DeviceID {
    /// Vendor that assigned `id`.
    pub vendor: Vendor,
    /// Vendor-specific device ID (high half of configuration dword 0x00).
    pub id:     u16,
}

impl DeviceID {
    /// Build a `DeviceID` from a vendor and its device ID.
    pub const fn new(vendor: Vendor, id: u16) -> Self {
        DeviceID { id, vendor }
    }
}

/// A PCI vendor, identified by the 16-bit vendor ID in the low half of configuration dword 0x00.
///
/// The explicit discriminants are the vendor IDs. Any ID without a named variant is carried
/// verbatim in `Vendor::Unknown`. When adding a vendor, keep the discriminant, the `From<u16>`
/// match and the `From<Vendor> for u16` match in sync.
#[derive(PartialEq, Debug, Copy, Clone, Eq)]
#[repr(u16)]
pub enum Vendor {
    ThreeDfxInteractiveInc = 0x121A,
    ThreeDLabs = 0x3D3D,
    AllianceSemiconductorCorp = 0x1142,
    ARKLogicInc = 0xEDD8,
    ATITechnologiesInc = 0x1002,
    AvanceLogicIncALI = 0x1005,
    ChipsandTechnologies = 0x102C,
    CirrusLogic = 0x1013,
    Compaq = 0x0E11,
    CyrixCorp = 0x1078,
    DiamondMultimediaSystems = 0x1092,
    DigitalEquipmentCorp = 0x1011,
    Iit = 0x1061,
    IntegratedMicroSolutionsInc = 0x10E0,
    IntelCorp = 0x8086,
    IntergraphicsSystems = 0x10EA,
    MacronixInc = 0x10D9,
    MatroxGraphicsInc = 0x102B,
    MiroComputersProductsAG = 0x1031,
    NationalSemiconductorCorp = 0x100B,
    NeoMagicCorp = 0x10C8,
    Number9ComputerCompany = 0x105D,
    NVidiaCorporation = 0x10DE,
    NVidiaSgsthomson = 0x12D2,
    OakTechnologyInc = 0x104E,
    Qemu = 0x1234,
    QuantumDesignsHKLtd = 0x1098,
    Real3D = 0x003D,
    Rendition = 0x1163,
    S3Inc = 0x5333,
    SierraSemiconductor = 0x10A8,
    SiliconIntegratedSystemsSiS = 0x1039,
    SiliconMotionInc = 0x126F,
    STBSystemsInc = 0x10B4,
    TexasInstruments = 0x104C,
    ToshibaAmericaInfoSystems = 0x1179,
    TridentMicrosystems = 0x1023,
    TsengLabsInc = 0x100C,
    TundraSemiconductorCorp = 0x10E3,
    VIATechnologiesInc = 0x1106,
    VirtIO = 0x1AF4,
    VMWareInc = 0x15AD,
    Weitek = 0x100E,
    /// Any vendor ID that has no named variant above.
    Unknown(u16),
}

impl From<u16> for Vendor {
    /// Map a raw vendor ID to its named variant, or `Unknown(vendor_id)` if there is none.
    fn from(vendor_id: u16) -> Self {
        use Vendor::*;
        match vendor_id {
            0x121A => ThreeDfxInteractiveInc,
            0x3D3D => ThreeDLabs,
            0x1142 => AllianceSemiconductorCorp,
            0xEDD8 => ARKLogicInc,
            0x1002 => ATITechnologiesInc,
            0x1005 => AvanceLogicIncALI,
            0x102C => ChipsandTechnologies,
            0x1013 => CirrusLogic,
            0x0E11 => Compaq,
            0x1078 => CyrixCorp,
            0x1092 => DiamondMultimediaSystems,
            0x1011 => DigitalEquipmentCorp,
            0x1061 => Iit,
            0x10E0 => IntegratedMicroSolutionsInc,
            0x8086 => IntelCorp,
            0x10EA => IntergraphicsSystems,
            0x10D9 => MacronixInc,
            0x102B => MatroxGraphicsInc,
            0x1031 => MiroComputersProductsAG,
            0x100B => NationalSemiconductorCorp,
            0x10C8 => NeoMagicCorp,
            0x105D => Number9ComputerCompany,
            0x10DE => NVidiaCorporation,
            0x12D2 => NVidiaSgsthomson,
            0x104E => OakTechnologyInc,
            0x1234 => Qemu,
            0x1098 => QuantumDesignsHKLtd,
            0x003D => Real3D,
            0x1163 => Rendition,
            0x5333 => S3Inc,
            0x10A8 => SierraSemiconductor,
            0x1039 => SiliconIntegratedSystemsSiS,
            0x126F => SiliconMotionInc,
            0x10B4 => STBSystemsInc,
            0x104C => TexasInstruments,
            0x1179 => ToshibaAmericaInfoSystems,
            0x1023 => TridentMicrosystems,
            0x100C => TsengLabsInc,
            0x10E3 => TundraSemiconductorCorp,
            0x1106 => VIATechnologiesInc,
            0x1AF4 => VirtIO,
            0x15AD => VMWareInc,
            0x100E => Weitek,
            id => Unknown(id),
        }
    }
}

/// Recover the raw 16-bit vendor ID, including the one carried by `Vendor::Unknown`.
///
/// Implemented as `From` rather than `Into`: the standard library derives `Into<u16> for Vendor`
/// from this, so both `u16::from(vendor)` and `vendor.into()` work.
impl From<Vendor> for u16 {
    fn from(vendor: Vendor) -> u16 {
        use Vendor::*;
        match vendor {
            ThreeDfxInteractiveInc => 0x121A,
            ThreeDLabs => 0x3D3D,
            AllianceSemiconductorCorp => 0x1142,
            ARKLogicInc => 0xEDD8,
            ATITechnologiesInc => 0x1002,
            AvanceLogicIncALI => 0x1005,
            ChipsandTechnologies => 0x102C,
            CirrusLogic => 0x1013,
            Compaq => 0x0E11,
            CyrixCorp => 0x1078,
            DiamondMultimediaSystems => 0x1092,
            DigitalEquipmentCorp => 0x1011,
            Iit => 0x1061,
            IntegratedMicroSolutionsInc => 0x10E0,
            IntelCorp => 0x8086,
            IntergraphicsSystems => 0x10EA,
            MacronixInc => 0x10D9,
            MatroxGraphicsInc => 0x102B,
            MiroComputersProductsAG => 0x1031,
            NationalSemiconductorCorp => 0x100B,
            NeoMagicCorp => 0x10C8,
            Number9ComputerCompany => 0x105D,
            NVidiaCorporation => 0x10DE,
            NVidiaSgsthomson => 0x12D2,
            OakTechnologyInc => 0x104E,
            Qemu => 0x1234,
            QuantumDesignsHKLtd => 0x1098,
            Real3D => 0x003D,
            Rendition => 0x1163,
            S3Inc => 0x5333,
            SierraSemiconductor => 0x10A8,
            SiliconIntegratedSystemsSiS => 0x1039,
            SiliconMotionInc => 0x126F,
            STBSystemsInc => 0x10B4,
            TexasInstruments => 0x104C,
            ToshibaAmericaInfoSystems => 0x1179,
            TridentMicrosystems => 0x1023,
            TsengLabsInc => 0x100C,
            TundraSemiconductorCorp => 0x10E3,
            VIATechnologiesInc => 0x1106,
            VirtIO => 0x1AF4,
            VMWareInc => 0x15AD,
            Weitek => 0x100E,
            Unknown(id) => id,
        }
    }
}

impl Display for Vendor {
    /// Write a vendor name followed by its ID as four-digit hex, e.g. `QEMU (0x1234)`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        use Vendor::*;

        match self {
            Qemu => write!(f, "QEMU (0x1234)"),
            VirtIO => write!(f, "VirtIO (0x1AF4)"),
            VMWareInc => write!(f, "VMWARE (0x15AD)"),
            S3Inc => write!(f, "S3 Incorporated (0x5333)"),
            IntelCorp => write!(f, "Intel Corp. (0x8086)"),
            ATITechnologiesInc => write!(f, "ATI (0x1002)"),
            // `#06X` = `0x` prefix plus four zero-padded upper-case hex digits. The previous
            // `{:#6}` printed space-padded *decimal*, because `#` has no effect on decimal output.
            Unknown(id) => write!(f, "Unknown ({:#06X})", id),
            other => write!(f, "{:?} ({:#06X})", other, u16::from(*other)),
        }
    }
}

use core::fmt::Display;

use {crate::device_tree::DeviceTree, x86_64::instructions::port::Port};

#[allow(non_camel_case_types, dead_code)]
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
/// Class specification for a PCI device.
///
/// The discriminants are base class codes; they match the high byte of the corresponding
/// `PciFullClass` variants. Codes without a variant map to `PciClass::Unknown`.
pub enum PciClass {
    Unclassified = 0x00,
    MassStorage = 0x01,
    Network = 0x02,
    Display = 0x03,
    Multimedia = 0x04,
    Memory = 0x05,
    Bridge = 0x06,
    Unknown = 0xFF,
}

impl From<u8> for PciClass {
    /// Convert a u8 into the corresponding PciClass.
    ///
    /// Codes `0x00..=0x06` map to the variant with that discriminant; anything else is `Unknown`.
    fn from(n: u8) -> Self {
        // Indexed by class code: `BY_CODE[n]` is the variant whose discriminant is `n`.
        const BY_CODE: [PciClass; 7] = [
            PciClass::Unclassified,
            PciClass::MassStorage,
            PciClass::Network,
            PciClass::Display,
            PciClass::Multimedia,
            PciClass::Memory,
            PciClass::Bridge,
        ];
        BY_CODE
            .get(usize::from(n))
            .copied()
            .unwrap_or(PciClass::Unknown)
    }
}

#[allow(non_camel_case_types, dead_code)]
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
/// Full class specification (type and subtype) for a PCI device.
///
/// Each discriminant is the class code in the high byte and the subclass in the low byte.
/// Variants are named `Class_Subclass`, hence the non-camel-case names.
pub enum PciFullClass {
    Unclassified_NonVgaCompatible = 0x0000,
    Unclassified_VgaCompatible = 0x0001,

    MassStorage_ScsiBus = 0x0100,
    MassStorage_IDE = 0x0101,
    MassStorage_Floppy = 0x0102,
    MassStorage_IpiBus = 0x0103,
    MassStorage_RAID = 0x0104,
    MassStorage_ATA = 0x0105,
    MassStorage_SATA = 0x0106,
    MassStorage_SerialSCSI = 0x0107,
    MassStorage_NVM = 0x0108,
    MassStorage_Other = 0x0180,

    Network_Ethernet = 0x0200,
    Network_TokenRing = 0x0201,
    Network_FDDI = 0x0202,
    Network_ATM = 0x0203,
    Network_ISDN = 0x0204,
    Network_WorldFlip = 0x0205,
    Network_PICMG = 0x0206,
    Network_Infiniband = 0x0207,
    Network_Fabric = 0x0208,
    Network_Other = 0x0280,

    Display_VGA = 0x0300,
    Display_XGA = 0x0301,
    Display_3D = 0x0302,
    Display_Other = 0x0380,

    Multimedia_Video = 0x0400,
    Multimedia_AudioController = 0x0401,
    Multimedia_Telephony = 0x0402,
    Multimedia_AudioDevice = 0x0403,
    Multimedia_Other = 0x0480,

    Memory_RAM = 0x0500,
    Memory_Flash = 0x0501,
    Memory_Other = 0x0580,

    Bridge_Host = 0x0600,
    Bridge_ISA = 0x0601,
    Bridge_EISA = 0x0602,
    Bridge_MCA = 0x0603,
    Bridge_PciToPci = 0x0604,
    Bridge_PCMCIA = 0x0605,
    Bridge_NuBus = 0x0606,
    Bridge_CardBus = 0x0607,
    Bridge_RACEway = 0x0608,
    Bridge_PciToPciSemiTransparent = 0x0609,
    Bridge_InfinibandToPci = 0x060A,
    Bridge_Other = 0x0680,

    Unknown = 0xFFFF,
}

impl PciFullClass {
    /// Every variant except `Unknown`.
    ///
    /// `from_u16` searches this list by discriminant, so each numeric code is written only once
    /// (on the enum itself) instead of being repeated in a `match`. This avoids pulling in
    /// `num`/`num-derive` just for the reverse mapping.
    const KNOWN: &'static [PciFullClass] = &[
        Self::Unclassified_NonVgaCompatible,
        Self::Unclassified_VgaCompatible,
        Self::MassStorage_ScsiBus,
        Self::MassStorage_IDE,
        Self::MassStorage_Floppy,
        Self::MassStorage_IpiBus,
        Self::MassStorage_RAID,
        Self::MassStorage_ATA,
        Self::MassStorage_SATA,
        Self::MassStorage_SerialSCSI,
        Self::MassStorage_NVM,
        Self::MassStorage_Other,
        Self::Network_Ethernet,
        Self::Network_TokenRing,
        Self::Network_FDDI,
        Self::Network_ATM,
        Self::Network_ISDN,
        Self::Network_WorldFlip,
        Self::Network_PICMG,
        Self::Network_Infiniband,
        Self::Network_Fabric,
        Self::Network_Other,
        Self::Display_VGA,
        Self::Display_XGA,
        Self::Display_3D,
        Self::Display_Other,
        Self::Multimedia_Video,
        Self::Multimedia_AudioController,
        Self::Multimedia_Telephony,
        Self::Multimedia_AudioDevice,
        Self::Multimedia_Other,
        Self::Memory_RAM,
        Self::Memory_Flash,
        Self::Memory_Other,
        Self::Bridge_Host,
        Self::Bridge_ISA,
        Self::Bridge_EISA,
        Self::Bridge_MCA,
        Self::Bridge_PciToPci,
        Self::Bridge_PCMCIA,
        Self::Bridge_NuBus,
        Self::Bridge_CardBus,
        Self::Bridge_RACEway,
        Self::Bridge_PciToPciSemiTransparent,
        Self::Bridge_InfinibandToPci,
        Self::Bridge_Other,
    ];

    /// Convert a u16 into the corresponding PciFullClass.
    ///
    /// Codes that match no variant (including `0xFFFF`) yield `PciFullClass::Unknown`.
    pub fn from_u16(n: u16) -> PciFullClass {
        Self::KNOWN
            .iter()
            .copied()
            .find(|candidate| candidate.as_u16() == n)
            .unwrap_or(PciFullClass::Unknown)
    }

    /// Convert a PciFullClass to its u16 representation (its discriminant).
    pub fn as_u16(&self) -> u16 {
        *self as u16
    }
}

impl From<u16> for PciFullClass {
    /// Convert a u16 into the corresponding PciFullClass (same as `PciFullClass::from_u16`).
    fn from(n: u16) -> Self {
        PciFullClass::from_u16(n)
    }
}

impl Display for PciFullClass {
    /// Write the variant name followed by its code in hex, e.g. `Display_VGA (0x0300)`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let code = self.as_u16();
        write!(f, "{:?} ({:#06X})", self, code)
    }
}

/// Build the value written to port 0xCF8 to select a configuration dword.
///
/// The layout is:
/// - bit 31: enable
/// - bits 16-23: bus
/// - bits 11-15: device
/// - bits 8-10: function
/// - low bits: register offset, rounded down to a multiple of four
fn pci_config_address(bus: u8, device: u8, func: u8, offset: u8) -> u32 {
    // Out-of-range values would spill into the neighbouring bit fields and silently address a
    // different device/function.
    debug_assert!(device < 32, "PCI device number out of range: {}", device);
    debug_assert!(func < 8, "PCI function number out of range: {}", func);
    0x8000_0000
        | (u32::from(bus) << 16)
        | (u32::from(device) << 11)
        | (u32::from(func) << 8)
        | u32::from(offset & 0xFC)
}

/// Read the 32-bit configuration dword at `offset` (rounded down to a multiple of four) of
/// `bus`/`device`/`func`.
///
/// # Safety
///
/// Performs raw port I/O on 0xCF8 and 0xCFC; the caller must be permitted to access those
/// ports. The address write and the data read are two separate port operations, so other
/// code using these ports concurrently could interleave with them.
unsafe fn pci_config_read(bus: u8, device: u8, func: u8, offset: u8) -> u32 {
    pci_config_read_2(bus, device, func, offset).0
}

/// Like `pci_config_read`, but also returns the address value written to port 0xCF8, as
/// `(data, address)`.
///
/// # Safety
///
/// Performs raw port I/O on 0xCF8 and 0xCFC; the caller must be permitted to access those
/// ports. The address write and the data read are two separate port operations, so other
/// code using these ports concurrently could interleave with them.
unsafe fn pci_config_read_2(bus: u8, device: u8, func: u8, offset: u8) -> (u32, u32) {
    let address = pci_config_address(bus, device, func, offset);
    Port::new(0xCF8).write(address);
    let data: u32 = Port::new(0xCFC).read();
    (data, address)
}

/// Write `value` to the configuration dword at `offset` (rounded down to a multiple of four) of
/// `bus`/`device`/`func`.
///
/// # Safety
///
/// Performs raw port I/O on 0xCF8 and 0xCFC; the caller must be permitted to access those
/// ports and must ensure the write is valid for the target register. The address write and the
/// data write are two separate port operations, so other code using these ports concurrently
/// could interleave with them.
unsafe fn pci_config_write(bus: u8, device: u8, func: u8, offset: u8, value: u32) {
    Port::new(0xCF8).write(pci_config_address(bus, device, func, offset));
    Port::new(0xCFC).write(value);
}

/// Read the header-type byte of `bus`/`device`/`function`.
///
/// This is byte 2 of configuration dword 0x0C, i.e. offset 0x0E.
///
/// # Panics
///
/// Panics if `device >= 32` or `function >= 8`.
fn get_header_type(bus: u8, device: u8, function: u8) -> u8 {
    assert!(device < 32);
    assert!(function < 8);
    // SAFETY: device and function are range-checked above, so the encoded address selects
    // exactly this function.
    let dword = unsafe { pci_config_read(bus, device, function, 0x0C) };
    let [_, _, header_type, _] = dword.to_le_bytes();
    header_type
}

/// Read the identification dword (offset 0x00) of `bus`/`device`/`function`.
///
/// Returns `(device_id, vendor_id)`: the device ID is the high half of the dword and the
/// vendor ID the low half.
///
/// # Panics
///
/// Panics if `device >= 32` or `function >= 8`.
fn get_ids(bus: u8, device: u8, function: u8) -> (u16, u16) {
    assert!(device < 32);
    assert!(function < 8);
    // SAFETY: device and function are range-checked above, so the encoded address selects
    // exactly this function.
    let dword = unsafe { pci_config_read(bus, device, function, 0) };
    let [vendor_lo, vendor_hi, device_lo, device_hi] = dword.to_le_bytes();
    let device_id = u16::from_le_bytes([device_lo, device_hi]);
    let vendor_id = u16::from_le_bytes([vendor_lo, vendor_hi]);
    (device_id, vendor_id)
}