.{memory: .{inb, outb, ins, alloc}, log} := @use("stn")

// I/O port addresses of the primary ATA bus: the task-file register block
// at 0x1F0-0x1F7 plus the alternate-status / device-control port at 0x3F6.
$ATA_PRIMARY_DATA := 0x1F0
$ATA_PRIMARY_ERR := 0x1F1
$ATA_PRIMARY_SECCOUNT := 0x1F2
$ATA_PRIMARY_LBA_LO := 0x1F3
$ATA_PRIMARY_LBA_MID := 0x1F4
$ATA_PRIMARY_LBA_HI := 0x1F5
$ATA_PRIMARY_DRIVE_HEAD := 0x1F6
// Shared port: written to issue a command, read to get the status byte.
$ATA_PRIMARY_COMM_REGSTAT := 0x1F7
$ATA_PRIMARY_ALTSTAT_DCR := 0x3F6

// Bit masks for the status byte read from ATA_PRIMARY_COMM_REGSTAT.
// Written as hex: bit 0 = 0x01 ... bit 7 = 0x80.
$STAT_ERR := 0x01
$STAT_DRQ := 0x08
$STAT_SRV := 0x10
$STAT_DF := 0x20
$STAT_RDY := 0x40
$STAT_BSY := 0x80

// Which of the two devices on the primary bus to address (see select_drive).
Drive := enum {Master, Slave}

// Selects `drive` on the primary bus by writing the drive/head register.
// The two values differ only in bit 4 (0x10), which picks the slave device;
// 0xA0 is the base value used for the master.
select_drive := fn(drive: Drive): void {
	match drive {
		// 0xA0 | 0x10
		.Slave => outb(ATA_PRIMARY_DRIVE_HEAD, 0xB0),
		// base value, bit 4 clear
		.Master => outb(ATA_PRIMARY_DRIVE_HEAD, 0xA0),
	}
}

// Sends IDENTIFY DEVICE (0xEC) to `drive` on the primary ATA bus and reads the
// 256-word identification block into a freshly allocated buffer.
//
// Each failure path logs and returns 0: floating bus (status 0xFF), no drive
// (status 0 after the command), non-ATA device (non-zero LBA mid/hi), or
// drive error (ERR set).
// NOTE(review): the success path also returns 0, so callers cannot tell
// success from failure; consider returning a real status code.
// NOTE(review): the buffer from `alloc` is never released; no matching free
// is visible in this file.
// NOTE(review): the status polling loops have no timeout and will spin
// forever on a hung device.
identify := fn(drive: Drive): u8 {
	// A floating bus (nothing attached) reads back as all ones.
	if inb(ATA_PRIMARY_COMM_REGSTAT) == 0xFF {
		log.error("(ata: drive not present) status=0xFF")
		return 0
	}

	select_drive(drive)
	// Let the newly selected drive settle before touching its registers.
	delay_400ns()

	// IDENTIFY expects sector count and all LBA registers to be zero.
	outb(ATA_PRIMARY_SECCOUNT, 0)
	outb(ATA_PRIMARY_LBA_LO, 0)
	outb(ATA_PRIMARY_LBA_MID, 0)
	outb(ATA_PRIMARY_LBA_HI, 0)
	// 0xEC = IDENTIFY DEVICE. No other command may be written until it completes.
	outb(ATA_PRIMARY_COMM_REGSTAT, 0xEC)

	// A status of 0 immediately after the command means no drive is attached.
	status := inb(ATA_PRIMARY_COMM_REGSTAT)
	if status == 0 {
		log.error("(ata: drive not present) status=0")
		return 0
	}

	loop if (status & STAT_BSY) == 0 break else {
		if DEBUG_PRINT log.printf("(ata: waiting for busy to end) status={}", .(status), .{radix: 16, log: .Warn})
		status = inb(ATA_PRIMARY_COMM_REGSTAT)
	}

	// Non-zero LBA mid/hi once BSY clears indicates a non-ATA device (e.g. ATAPI).
	mid := inb(ATA_PRIMARY_LBA_MID)
	hi := inb(ATA_PRIMARY_LBA_HI)
	if (mid | hi) != 0 {
		log.error("the drive is not ata...?")
		return 0
	}

	loop if (status & (STAT_ERR | STAT_DRQ)) != 0 break else {
		if DEBUG_PRINT log.printf("(ata: waiting for ERR or DRQ) status={}", .(status), .{radix: 16, log: .Warn})
		status = inb(ATA_PRIMARY_COMM_REGSTAT)
	}

	if (status & STAT_ERR) != 0 {
		// Always report the failure, like every other error path; the
		// detailed (code-size-heavy) printf is used only in debug builds.
		if DEBUG_PRINT {
			log.printf("(ata: drive error) status={}", .(status), .{radix: 16, log: .Error})
		} else {
			log.error("(ata: drive error)")
		}
		return 0
	}

	if DEBUG_PRINT log.printf("status={}", .(status), .{radix: 16})

	// The IDENTIFY data block is exactly 256 words. All of them must be read,
	// or a word is left pending in the data port and DRQ never clears.
	buffer := alloc(u16, 256)[0..256]
	read(buffer)

	if DEBUG_PRINT {
		// Word 83, bit 10: 48-bit LBA supported.
		if (buffer[83] & (1 << 10)) != 0 {
			log.info("LBA48 mode supported")
			// Words 100..103 form the 64-bit count of LBA48-addressable sectors.
			log.printf("{} 48 bit addressable sectors", *@as(^uint, @bitcast(buffer[100..].ptr)), .{})
		}
		log.print(buffer, .{})
	}

	return 0
}

// Customary ~400ns settle delay after a drive select: four reads of the
// alternate status port. Unlike the main status port, reading it does not
// acknowledge a pending interrupt.
delay_400ns := fn(): void {
	inb(ATA_PRIMARY_ALTSTAT_DCR)
	inb(ATA_PRIMARY_ALTSTAT_DCR)
	inb(ATA_PRIMARY_ALTSTAT_DCR)
	inb(ATA_PRIMARY_ALTSTAT_DCR)
}

// Fills every element of `buffer`, in order, with one 16-bit read from the
// primary data port. An empty slice performs no reads.
read := fn(buffer: []u16): void {
	n := 0
	loop {
		if n == buffer.len break
		buffer[n] = ins(ATA_PRIMARY_DATA)
		n += 1
	}
}

// Enables verbose status/identify logging in `identify`. Turning it on makes
// the generated assembly much larger, mostly from the log.printf calls.
$DEBUG_PRINT := true

// Entry point: probe the master device on the primary ATA bus.
// The u8 returned by identify is ignored.
main := fn(): void {
	target: Drive = .Master
	identify(target)
}