use {
    spin::Lazy,
    x86_64::{
        structures::{
            gdt::{Descriptor, GlobalDescriptorTable, SegmentSelector},
            tss::TaskStateSegment,
        },
        VirtAddr,
    },
};

/// Slot in the TSS interrupt stack table that receives the dedicated double-fault stack.
///
/// NOTE(review): presumably also passed to the IDT double-fault entry's stack-index setter
/// elsewhere (hence `u16`) — confirm at that call site.
pub const DOUBLE_FAULT_IX: u16 = 0_u16;

/// Size in bytes of the heap-allocated stack installed in the TSS interrupt stack table (5 KiB).
///
/// NOTE(review): 5 KiB is not a multiple of `STACK_ALIGNMENT` (4 KiB); confirm that 5 KiB,
/// rather than five 4 KiB units, is the intended size.
const STACK_SIZE: usize = 5 << 10;
/// Alignment in bytes requested for that stack allocation (4 KiB).
const STACK_ALIGNMENT: usize = 1 << 12;

/// Loads the kernel GDT, reloads `CS`/`SS` with its kernel code/data selectors, and loads
/// the task register with the TSS selector.
///
/// # Safety
///
/// - Executes privileged instructions (`lgdt`, segment register reloads, `ltr`), so it must
///   run at the highest privilege level.
/// - First access to the GDT forces the TSS to be built, which allocates its interrupt stack
///   from the global allocator; that allocator must already be usable.
/// - `ltr` faults when the referenced TSS descriptor is already marked busy, so this must not
///   be run twice against the same GDT. NOTE(review): confirm no other CPU calls it as well.
pub unsafe fn init() {
    use x86_64::instructions::{
        segmentation::{Segment, CS, SS},
        tables::load_tss,
    };

    log::trace!("Initialising GDT");

    let (table, selectors) = &*GDT;
    table.load();

    // Reload segment registers so they reference entries of the table just loaded.
    CS::set_reg(selectors.kcode);
    SS::set_reg(selectors.kdata);

    load_tss(selectors.tss);
}

/// Selectors for the entries appended to the kernel GDT.
struct Selectors {
    /// Kernel code segment; loaded into `CS` by `init`.
    kcode: SegmentSelector,
    /// Kernel data segment; loaded into `SS` by `init`.
    kdata: SegmentSelector,
    /// TSS descriptor; loaded into the task register by `init`.
    tss: SegmentSelector,
}

/// Task State Segment referenced by the kernel GDT, built on first access.
///
/// Its only customisation is interrupt stack table slot [`DOUBLE_FAULT_IX`], which receives a
/// freshly heap-allocated, zeroed stack of `STACK_SIZE` bytes aligned to `STACK_ALIGNMENT`.
///
/// # Panics / aborts
///
/// Panics if the stack layout is invalid. Diverges through `handle_alloc_error` if the
/// stack allocation fails.
static TSS: Lazy<TaskStateSegment> = Lazy::new(|| {
    let mut tss = TaskStateSegment::new();

    let layout = alloc::alloc::Layout::from_size_align(STACK_SIZE, STACK_ALIGNMENT)
        .expect("Failed to create stack layout");

    // SAFETY: `layout` has a non-zero size (STACK_SIZE is a positive constant), which is
    // the only precondition of `alloc_zeroed`.
    let stack = unsafe { alloc::alloc::alloc_zeroed(layout) };
    if stack.is_null() {
        // Without this check the IST entry would silently become `0 + STACK_SIZE`, a bogus
        // near-null stack pointer. Report the allocation failure instead.
        alloc::alloc::handle_alloc_error(layout);
    }

    // The allocation is intentionally never freed: the TSS is a static that lives forever.
    // x86_64 stacks grow downwards, so the IST entry holds the end (top) address.
    let stack_top = VirtAddr::from_ptr(stack) + STACK_SIZE as u64;

    tss.interrupt_stack_table[usize::from(DOUBLE_FAULT_IX)] = stack_top;
    tss
});

/// Kernel GDT paired with the selectors of the entries appended to it, built on first access.
///
/// Entries are appended in this order: kernel code, kernel data, then the TSS descriptor.
/// Building the TSS descriptor forces initialisation of `TSS`.
static GDT: Lazy<(GlobalDescriptorTable, Selectors)> = Lazy::new(|| {
    let mut table = GlobalDescriptorTable::new();

    let kcode = table.append(Descriptor::kernel_code_segment());
    let kdata = table.append(Descriptor::kernel_data_segment());
    let tss = table.append(Descriptor::tss_segment(&TSS));

    (table, Selectors { kcode, kdata, tss })
});