diff --git a/hblang/command-help.txt b/hblang/command-help.txt new file mode 100644 index 0000000..9850393 --- /dev/null +++ b/hblang/command-help.txt @@ -0,0 +1,3 @@ +--fmt - format all source files +--fmt-current - format mentioned file + --fmt-stdout - dont write the formatted file but print it diff --git a/hblang/src/codegen.rs b/hblang/src/codegen.rs index 30e558c..2e3a905 100644 --- a/hblang/src/codegen.rs +++ b/hblang/src/codegen.rs @@ -1580,7 +1580,9 @@ impl Codegen { }); Some(Value::void()) } - E::Call { func: fast, args } => { + E::Call { + func: fast, args, .. + } => { let func_ty = self.ty(fast); let ty::Kind::Func(mut func_id) = func_ty.expand() else { self.report(fast.pos(), "can't call this, maybe in the future"); diff --git a/hblang/src/main.rs b/hblang/src/main.rs index fd63113..3771633 100644 --- a/hblang/src/main.rs +++ b/hblang/src/main.rs @@ -1,12 +1,56 @@ fn main() -> std::io::Result<()> { - let root = std::env::args() - .nth(1) - .unwrap_or_else(|| "main.hb".to_string()); + let args = std::env::args().collect::>(); + let args = args.iter().map(String::as_str).collect::>(); + let root = args.get(1).copied().unwrap_or("main.hb"); - let parsed = hblang::parse_from_fs(1, &root)?; - let mut codegen = hblang::codegen::Codegen::default(); - codegen.files = parsed; + if args.contains(&"--help") || args.contains(&"-h") { + println!("Usage: hbc [OPTIONS...] "); + println!(include_str!("../command-help.txt")); + return Err(std::io::ErrorKind::Other.into()); + } - codegen.generate(); - codegen.dump(&mut std::io::stdout()) + let parsed = hblang::parse_from_fs(1, root)?; + + fn format_to_stdout(ast: hblang::parser::Ast) -> std::io::Result<()> { + let source = std::fs::read_to_string(&*ast.path)?; + hblang::parser::with_fmt_source(&source, || { + for expr in ast.exprs() { + use std::io::Write; + writeln!(std::io::stdout(), "{expr}")?; + } + std::io::Result::Ok(()) + }) + } + + fn format_ast(ast: hblang::parser::Ast) -> std::io::Result<()> { + let source = std::fs::read_to_string(&*ast.path)?; + let mut output = Vec::new(); + hblang::parser::with_fmt_source(&source, || { + for expr in ast.exprs() { + use std::io::Write; + writeln!(output, "{expr}")?; + } + std::io::Result::Ok(()) + })?; + + std::fs::write(&*ast.path, output)?; + + Ok(()) + } + + if args.contains(&"--fmt") { + for parsed in parsed { + format_ast(parsed)?; + } + } else if args.contains(&"--fmt-current") { + format_to_stdout(parsed.into_iter().next().unwrap())?; + } else { + let mut codegen = hblang::codegen::Codegen::default(); + codegen.files = parsed; + + codegen.generate(); + codegen.dump(&mut std::io::stdout())?; + } + + Ok(()) } diff --git a/hblang/src/parser.rs b/hblang/src/parser.rs index 28ec64d..436fe63 100644 --- a/hblang/src/parser.rs +++ b/hblang/src/parser.rs @@ -58,15 +58,16 @@ struct ScopeIdent { } pub struct Parser<'a, 'b> { - path: &'b str, - loader: Loader<'b>, - lexer: Lexer<'b>, - arena: &'b Arena<'a>, - token: Token, - symbols: &'b mut Symbols, - ns_bound: usize, - idents: Vec, - captured: Vec, + path: &'b str, + loader: Loader<'b>, + lexer: Lexer<'b>, + arena: &'b Arena<'a>, + token: Token, + symbols: &'b mut Symbols, + ns_bound: usize, + trailing_sep: bool, + idents: Vec, + captured: Vec, } impl<'a, 'b> Parser<'a, 'b> { @@ -80,6 +81,7 @@ impl<'a, 'b> Parser<'a, 'b> { arena, symbols, ns_bound: 0, + trailing_sep: false, idents: Vec::new(), captured: Vec::new(), } @@ -384,6 +386,7 @@ impl<'a, 'b> Parser<'a, 'b> { T::LParen => Expr::Call { func: self.arena.alloc(expr), args: self.collect_list(T::Comma, T::RParen, Self::expr), + trailing_comma: std::mem::take(&mut self.trailing_sep), }, T::Ctor => E::Ctor { pos: token.start, @@ -464,7 +467,7 @@ impl<'a, 'b> Parser<'a, 'b> { self.collect(|s| { s.advance_if(end).not().then(|| { let val = f(s); - s.advance_if(delim); + s.trailing_sep = s.advance_if(delim); val }) }) @@ -600,6 +603,7 @@ generate_expr! { Call { func: &'a Self, args: &'a [Self], + trailing_comma: bool, }, Return { pos: Pos, @@ -686,6 +690,17 @@ impl<'a> Poser for &Expr<'a> { } } +thread_local! { + static FMT_SOURCE: Cell<*const str> = const { Cell::new("") }; +} + +pub fn with_fmt_source(source: &str, f: impl FnOnce() -> T) -> T { + FMT_SOURCE.with(|s| s.set(source)); + let r = f(); + FMT_SOURCE.with(|s| s.set("")); + r +} + impl<'a> std::fmt::Display for Expr<'a> { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { thread_local! { @@ -708,6 +723,33 @@ impl<'a> std::fmt::Display for Expr<'a> { write!(f, "{end}") } + fn fmt_trailing_list( + f: &mut std::fmt::Formatter, + end: &str, + list: &[T], + fmt: impl Fn(&T, &mut std::fmt::Formatter) -> std::fmt::Result, + ) -> std::fmt::Result { + writeln!(f)?; + INDENT.with(|i| i.set(i.get() + 1)); + let res = (|| { + for stmt in list { + for _ in 0..INDENT.with(|i| i.get()) { + write!(f, "\t")?; + } + fmt(stmt, f)?; + writeln!(f, ",")?; + } + Ok(()) + })(); + INDENT.with(|i| i.set(i.get() - 1)); + + for _ in 0..INDENT.with(|i| i.get()) { + write!(f, "\t")?; + } + write!(f, "{end}")?; + res + } + macro_rules! impl_parenter { ($($name:ident => $pat:pat,)*) => { $( @@ -732,8 +774,21 @@ impl<'a> std::fmt::Display for Expr<'a> { Consecutive => Expr::UnOp { .. }, } + { + let source = unsafe { &*FMT_SOURCE.with(|s| s.get()) }; + let pos = self.pos(); + + if let Some(before) = source.get(..pos as usize) { + let trailing_whitespace = &before[before.trim_end().len()..]; + let ncount = trailing_whitespace.chars().filter(|&c| c == '\n').count(); + if ncount > 1 { + writeln!(f)?; + } + } + } + match *self { - Self::Comment { literal, .. } => write!(f, "{literal}"), + Self::Comment { literal, .. } => write!(f, "{}", literal.trim_end()), Self::Mod { path, .. } => write!(f, "@mod(\"{path}\")"), Self::Field { target, field } => write!(f, "{}.{field}", Postfix(target)), Self::Directive { name, args, .. } => { @@ -787,28 +842,24 @@ impl<'a> std::fmt::Display for Expr<'a> { fmt_list(f, "", args, |arg, f| write!(f, "{}: {}", arg.name, arg.ty))?; write!(f, "): {ret} {body}") } - Self::Call { func, args } => { + Self::Call { + func, + args, + trailing_comma, + } => { write!(f, "{}(", Postfix(func))?; - fmt_list(f, ")", args, std::fmt::Display::fmt) + if trailing_comma { + fmt_trailing_list(f, ")", args, std::fmt::Display::fmt) + } else { + fmt_list(f, ")", args, std::fmt::Display::fmt) + } } Self::Return { val: Some(val), .. } => write!(f, "return {val};"), Self::Return { val: None, .. } => write!(f, "return;"), Self::Ident { name, .. } => write!(f, "{name}"), Self::Block { stmts, .. } => { - writeln!(f, "{{")?; - INDENT.with(|i| i.set(i.get() + 1)); - let res = (|| { - for stmt in stmts { - for _ in 0..INDENT.with(|i| i.get()) { - write!(f, " ")?; - } - writeln!(f, "{stmt}")?; - } - Ok(()) - })(); - INDENT.with(|i| i.set(i.get() - 1)); - write!(f, "}}")?; - res + write!(f, "{{")?; + fmt_trailing_list(f, "}", stmts, std::fmt::Display::fmt) } Self::Number { value, .. } => write!(f, "{value}"), Self::Bool { value, .. } => write!(f, "{value}"), @@ -1109,3 +1160,51 @@ impl Drop for ArenaChunk { } } } + +#[cfg(test)] +mod test { + fn format(ident: &str, input: &str) { + let ast = super::Ast::new(ident, input, &super::no_loader); + let mut output = String::new(); + super::with_fmt_source(input, || { + for expr in ast.exprs() { + use std::fmt::Write; + writeln!(output, "{expr}").unwrap(); + } + }); + + let input_path = format!("formatter_{ident}.expected"); + let output_path = format!("formatter_{ident}.actual"); + std::fs::write(&input_path, input).unwrap(); + std::fs::write(&output_path, output).unwrap(); + + let success = std::process::Command::new("diff") + .arg("-u") + .arg("--color") + .arg(&input_path) + .arg(&output_path) + .status() + .unwrap() + .success(); + std::fs::remove_file(&input_path).unwrap(); + std::fs::remove_file(&output_path).unwrap(); + assert!(success, "test failed"); + } + + macro_rules! test { + ($($name:ident => $input:expr;)*) => {$( + #[test] + fn $name() { + format(stringify!($name), $input); + } + )*}; + } + + test! { + comments => "// comment\n// comment\n\n// comment\n\n\ + /* comment */\n/* comment */\n\n/* comment */\n"; + some_ordinary_code => "loft := fn(): int return loft(1, 2, 3);\n"; + some_arg_per_line_code => "loft := fn(): int return loft(\ + \n\t1,\n\t2,\n\t3,\n);\n"; + } +}