/// A single lexed token: its kind plus the half-open byte range `start..end`
/// it covers in the source text.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct Token {
    pub kind:  TokenKind,
    /// Byte offset of the first byte covered by the token.
    pub start: u32,
    /// Byte offset one past the last byte covered by the token.
    pub end:   u32,
}

impl Token {
    /// Returns the token's byte range, suitable for slicing the source text.
    pub fn range(&self) -> std::ops::Range<usize> {
        self.start as usize..self.end as usize
    }

    /// Returns the number of bytes the token covers.
    pub fn len(&self) -> u32 {
        self.end - self.start
    }

    /// Returns `true` if the token covers no bytes (e.g. an end-of-input token).
    pub fn is_empty(&self) -> bool {
        self.start == self.end
    }
}

/// Generates a token-kind enum, its `Display` impl and helper methods from one
/// declarative table.
///
/// The table has four sections, in this order:
/// - `#[patterns]`: variants with no fixed text; displayed as `<Name>`. Must
///   include an `Ident` variant, which `from_ident` falls back to.
/// - `#[keywords]`: `Name = b"text"`; matched by `from_ident` and displayed as `text`.
/// - `#[punkt]`: `Name = "text"` punctuation; displayed as `text`.
/// - `#[ops]`: operator groups, each introduced by `#[prec]`. An optional
///   `=> AssignName` also generates the compound-assignment variant `text=`.
///
/// Attributes written before the enum (e.g. doc comments) are forwarded to it.
///
/// Uses `${index()}` / `${ignore()}`, which need the nightly
/// `macro_metavar_expr` feature enabled at the crate root.
macro_rules! gen_token_kind {
    (
        $(#[$atts:meta])*
        $vis:vis enum $name:ident {
            #[patterns] $(
                $pattern:ident,
            )*
            #[keywords] $(
                $keyword:ident = $keyword_lit:literal,
            )*
            #[punkt] $(
                $punkt:ident = $punkt_lit:literal,
            )*
            #[ops] $(
                #[$prec:ident] $(
                    $op:ident = $op_lit:literal $(=> $assign:ident)?,
                )*
            )*
        }
    ) => {
        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                let s = match *self {
                    $( Self::$pattern => concat!('<', stringify!($pattern), '>'), )*

                    // Keyword literals are byte strings (so `from_ident` can match on
                    // `&[u8]`); decode them instead of `stringify!`ing, which would
                    // print the source form `b"..."`.
                    $( Self::$keyword => std::str::from_utf8($keyword_lit)
                        .expect("keyword literals must be valid UTF-8"), )*
                    // Print the punctuation itself, not its quoted source form,
                    // matching how operators are displayed below.
                    $( Self::$punkt   => $punkt_lit, )*
                    $($( Self::$op    => $op_lit,
                      $(Self::$assign => concat!($op_lit, "="),)?)*)*
                };
                f.write_str(s)
            }
        }

        impl $name {
            /// Returns the binary-operator precedence level of this token, or
            /// `None` if it is not an operator.
            ///
            /// Levels are numbered from 1 in the order the `#[prec]` groups are
            /// listed. Compound-assignment variants all get level 1.
            #[inline(always)]
            pub fn precedence(&self) -> Option<u8> {
                Some(match self {
                    $($(Self::$op => ${ignore($prec)} ${index(1)},
                      $(Self::$assign => 0,)?)*)*
                    _ => return None,
                } + 1)
            }

            /// Maps identifier bytes to the matching keyword kind, or `Ident`
            /// when they are not a keyword.
            #[inline(always)]
            fn from_ident(ident: &[u8]) -> Self {
                match ident {
                    $($keyword_lit => Self::$keyword,)*
                    _ => Self::Ident,
                }
            }

            /// For a compound-assignment kind (e.g. `AddAss`), returns the
            /// underlying operator (e.g. `Add`); otherwise `None`.
            pub fn assign_op(&self) -> Option<Self> {
                Some(match self {
                    $($($(Self::$assign => Self::$op,)?)*)*
                    _ => return None,
                })
            }
        }

        $(#[$atts])*
        #[derive(Debug, PartialEq, Eq, Clone, Copy)]
        $vis enum $name {
            $( $pattern, )*
            $( $keyword, )*
            $( $punkt,   )*
            $($( $op, $($assign,)?  )*)*
        }
    };
}

// The lexer's token vocabulary.
// Within `#[ops]`, each `#[prec]` starts a new precedence level.
// `TokenKind::precedence` numbers the levels 1, 2, ... in the order they are
// listed here, and compound-assignment variants (`=> XxxAss`) share level 1.
// NOTE(review): presumably a higher level binds tighter; confirm against the parser.
// Only `//` comments are used inside the invocation: `///` doc comments would
// become attributes that the macro's patterns do not accept here.
gen_token_kind! {
    pub enum TokenKind {
        #[patterns]
        // Tokens whose text is not fixed.
        Ident,
        Number,
        // End of input; the lexer emits it with an empty range.
        Eof,
        // A byte that starts no recognised token.
        Error,
        // `@name`; the lexer excludes the leading `@` from the token's range.
        // NOTE(review): misspelling of "Directive"; kept because other code names it.
        Driective,
        #[keywords]
        // Byte-string literals so `TokenKind::from_ident` can match raw identifier bytes.
        // NOTE(review): there is `True` but no `False` keyword; confirm this is intended.
        Return   = b"return",
        If       = b"if",
        Else     = b"else",
        Loop     = b"loop",
        Break    = b"break",
        Continue = b"continue",
        Fn       = b"fn",
        Struct   = b"struct",
        True     = b"true",
        #[punkt]
        LParen = "(",
        RParen = ")",
        LBrace = "{",
        RBrace = "}",
        Semi   = ";",
        Colon  = ":",
        Comma  = ",",
        Dot    = ".",
        // Two-byte openers lexed as single tokens.
        // NOTE(review): presumably the start of struct-constructor / tuple literals.
        Ctor   = ".{",
        Tupl   = ".(",
        #[ops]
        #[prec]
        Decl   = ":=",
        Assign = "=",
        #[prec]
        Or = "||",
        #[prec]
        And = "&&",
        #[prec]
        Bor = "|" => BorAss,
        #[prec]
        Xor = "^" => XorAss,
        #[prec]
        Band = "&" => BandAss,
        #[prec]
        Eq = "==",
        Ne = "!=",
        #[prec]
        Le = "<=",
        Ge = ">=",
        Lt = "<",
        Gt = ">",
        #[prec]
        Shl = "<<" => ShlAss,
        Shr = ">>" => ShrAss,
        #[prec]
        Add = "+" => AddAss,
        Sub = "-" => SubAss,
        #[prec]
        Mul = "*" => MulAss,
        Div = "/" => DivAss,
        Mod = "%" => ModAss,
    }
}

/// A byte-oriented lexer over a UTF-8 source string.
///
/// Tokens are produced on demand by [`Lexer::next`]. Once the input is
/// exhausted, every further call returns an empty `TokenKind::Eof` token.
pub struct Lexer<'a> {
    /// Byte offset of the next unread byte.
    pos:   u32,
    /// The source text; always the bytes of the `&'a str` given to [`Lexer::new`].
    bytes: &'a [u8],
}

impl<'a> Lexer<'a> {
    /// Creates a lexer positioned at the start of `input`.
    pub fn new(input: &'a str) -> Self {
        Self {
            pos:   0,
            bytes: input.as_bytes(),
        }
    }

    /// Returns the source text covered by `tok`, typically a token's `range()`.
    ///
    /// # Panics
    ///
    /// Panics if `tok` is out of bounds or does not lie on UTF-8 character
    /// boundaries. Ranges of tokens produced by this lexer always do.
    pub fn slice(&self, tok: std::ops::Range<usize>) -> &'a str {
        // SAFETY: `bytes` is only ever set from `str::as_bytes` in `new`, so the
        // whole buffer is valid UTF-8. The sub-slice below goes through `str`
        // indexing, which checks char boundaries instead of trusting `tok`.
        let source = unsafe { std::str::from_utf8_unchecked(self.bytes) };
        &source[tok]
    }

    /// Returns the next unread byte without consuming it.
    fn peek(&self) -> Option<u8> {
        self.bytes.get(self.pos as usize).copied()
    }

    /// Consumes and returns the next byte, or `None` at end of input.
    fn advance(&mut self) -> Option<u8> {
        let c = self.peek()?;
        self.pos += 1;
        Some(c)
    }

    /// Lexes and returns the next token, skipping ASCII whitespace.
    ///
    /// At end of input this returns an empty `TokenKind::Eof` token, and keeps
    /// doing so on every later call. A character that starts no token yields a
    /// `TokenKind::Error` token spanning that whole (possibly multi-byte) character.
    pub fn next(&mut self) -> Token {
        use TokenKind as T;
        loop {
            let mut start = self.pos;

            let Some(c) = self.advance() else {
                return Token {
                    kind: T::Eof,
                    start,
                    end: self.pos,
                };
            };

            let kind = match c {
                b'\n' | b'\r' | b'\t' | b' ' => continue,
                b'0'..=b'9' => {
                    while let Some(b'0'..=b'9') = self.peek() {
                        self.advance();
                    }
                    T::Number
                }
                c @ (b'a'..=b'z' | b'A'..=b'Z' | b'_' | b'@') => {
                    while let Some(b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_') = self.peek() {
                        self.advance();
                    }

                    if c == b'@' {
                        // The directive token covers only the name after `@`.
                        start += 1;
                        T::Driective
                    } else {
                        let ident = &self.bytes[start as usize..self.pos as usize];
                        T::from_ident(ident)
                    }
                }
                // Guards with `advance_if` consume the second byte only when it matches,
                // so a failed guard falls through to the next arm untouched.
                b':' if self.advance_if(b'=') => T::Decl,
                b':' => T::Colon,
                b',' => T::Comma,
                b'.' if self.advance_if(b'{') => T::Ctor,
                b'.' if self.advance_if(b'(') => T::Tupl,
                b'.' => T::Dot,
                b';' => T::Semi,
                b'!' if self.advance_if(b'=') => T::Ne,
                b'=' if self.advance_if(b'=') => T::Eq,
                b'=' => T::Assign,
                b'<' if self.advance_if(b'=') => T::Le,
                b'<' if self.advance_if(b'<') => match self.advance_if(b'=') {
                    true => T::ShlAss,
                    false => T::Shl,
                },
                b'<' => T::Lt,
                b'>' if self.advance_if(b'=') => T::Ge,
                b'>' if self.advance_if(b'>') => match self.advance_if(b'=') {
                    true => T::ShrAss,
                    false => T::Shr,
                },
                b'>' => T::Gt,
                b'+' if self.advance_if(b'=') => T::AddAss,
                b'+' => T::Add,
                b'-' if self.advance_if(b'=') => T::SubAss,
                b'-' => T::Sub,
                b'*' if self.advance_if(b'=') => T::MulAss,
                b'*' => T::Mul,
                b'/' if self.advance_if(b'=') => T::DivAss,
                b'/' => T::Div,
                b'%' if self.advance_if(b'=') => T::ModAss,
                b'%' => T::Mod,
                b'&' if self.advance_if(b'=') => T::BandAss,
                b'&' if self.advance_if(b'&') => T::And,
                b'&' => T::Band,
                b'^' if self.advance_if(b'=') => T::XorAss,
                b'^' => T::Xor,
                b'|' if self.advance_if(b'=') => T::BorAss,
                b'|' if self.advance_if(b'|') => T::Or,
                b'|' => T::Bor,
                b'(' => T::LParen,
                b')' => T::RParen,
                b'{' => T::LBrace,
                b'}' => T::RBrace,
                _ => {
                    // Consume the rest of a multi-byte UTF-8 character. Continuation
                    // bytes are 0b10xxxxxx. This keeps the token's range on char
                    // boundaries, so it can be sliced out of the source `str`.
                    while let Some(0x80..=0xBF) = self.peek() {
                        self.advance();
                    }
                    T::Error
                }
            };

            return Token {
                kind,
                start,
                end: self.pos,
            };
        }
    }

    /// Consumes the next byte and returns `true` if it equals `arg`; otherwise
    /// leaves the position unchanged and returns `false`.
    fn advance_if(&mut self, arg: u8) -> bool {
        if self.peek() == Some(arg) {
            self.advance();
            true
        } else {
            false
        }
    }

    /// Converts a byte offset into this lexer's source into a 1-based
    /// `(line, column)` pair by delegating to the free function `line_col`.
    pub fn line_col(&self, pos: u32) -> (usize, usize) {
        line_col(self.bytes, pos)
    }
}

/// Converts the byte offset `start` into a 1-based `(line, column)` pair.
///
/// Lines are split on `\n`, and the column counts bytes, not characters.
/// An offset that points at a line's terminating `\n`, or at the very end of
/// the input, maps to the column just past that line's last byte. Offsets
/// beyond the end of `bytes` fall back to `(1, 1)`.
pub fn line_col(bytes: &[u8], mut start: u32) -> (usize, usize) {
    bytes
        .split(|&b| b == b'\n')
        .enumerate()
        .find_map(|(i, line)| {
            let len = line.len() as u32;
            // `<=` so the line's `\n` (or end of input) belongs to this line.
            // With `<`, such offsets reached the subtraction below with
            // `start == len` and underflowed.
            if start <= len {
                return Some((i + 1, start as usize + 1));
            }
            // Here `start >= len + 1`, so this cannot underflow.
            start -= len + 1;
            None
        })
        .unwrap_or((1, 1))
}

impl<'a> Iterator for Lexer<'a> {
    type Item = Token;

    fn next(&mut self) -> Option<Self::Item> {
        use TokenKind as T;
        loop {
            let mut start = self.pos;
            let kind = match self.advance()? {
                b'\n' | b'\r' | b'\t' | b' ' => continue,
                b'0'..=b'9' => {
                    while let Some(b'0'..=b'9') = self.peek() {
                        self.advance();
                    }
                    T::Number
                }
                c @ (b'a'..=b'z' | b'A'..=b'Z' | b'_' | b'@') => {
                    while let Some(b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_') = self.peek() {
                        self.advance();
                    }

                    if c == b'@' {
                        start += 1;
                        T::Driective
                    } else {
                        let ident = &self.bytes[start as usize..self.pos as usize];
                        T::from_ident(ident)
                    }
                }
                b':' if self.advance_if(b'=') => T::Decl,
                b':' => T::Colon,
                b',' => T::Comma,
                b'.' if self.advance_if(b'{') => T::Ctor,
                b'.' if self.advance_if(b'(') => T::Tupl,
                b'.' => T::Dot,
                b';' => T::Semi,
                b'!' if self.advance_if(b'=') => T::Ne,
                b'=' if self.advance_if(b'=') => T::Eq,
                b'=' => T::Assign,
                b'<' if self.advance_if(b'=') => T::Le,
                b'<' => T::Lt,
                b'>' if self.advance_if(b'=') => T::Ge,
                b'>' => T::Gt,
                b'+' => T::Add,
                b'-' => T::Sub,
                b'*' => T::Mul,
                b'/' => T::Div,
                b'&' => T::Band,
                b'(' => T::LParen,
                b')' => T::RParen,
                b'{' => T::LBrace,
                b'}' => T::RBrace,
                _ => T::Error,
            };

            return Some(Token {
                kind,
                start,
                end: self.pos,
            });
        }
    }
}

#[cfg(test)]
mod tests {
    /// Snapshot driver: lexes `input` to completion and appends one
    /// `<kind> <text>` line per token, including the trailing `Eof`, to `output`.
    fn lex(input: &'static str, output: &mut String) {
        use {
            super::{Lexer, TokenKind},
            std::fmt::Write,
        };

        let mut lexer = Lexer::new(input);
        let mut reached_eof = false;
        while !reached_eof {
            let token = lexer.next();
            let text = &input[token.range()];
            writeln!(output, "{:?} {:?}", token.kind, text).unwrap();
            reached_eof = token.kind == TokenKind::Eof;
        }
    }

    crate::run_tests! { lex:
        empty => "";
        whitespace => " \t\n\r";
        example => include_str!("../examples/main_fn.hb");
        arithmetic => include_str!("../examples/arithmetic.hb");
    }
}