open Lexer

(** Parser state threaded through every parser in this file. *)
type parser_context = {
  (* Tokens not yet consumed, each paired with the [Lexer.lexer_context]
     the lexer attached to it. *)
  seq: (Token.t * Lexer.lexer_context) Seq.t;
  (* NOTE(review): no parser in this file reads or appends to [errors];
     [get_expr_tree_from_tokens] only initializes it to [[]]. Confirm
     whether it is still needed. *)
  errors: string list;
}

(** A parser consumes a prefix of the context's token stream. On success it
    returns [Some (value, remaining_context)]; on failure it returns [None]
    and carries no diagnostic. The context is an immutable record, so a
    choice can backtrack by re-running another parser on the same context,
    provided [seq] can be traversed more than once (presumably true for the
    lexer's sequence; confirm). *)
type 'a parser = parser_context -> ('a * parser_context) option

(** [return a] succeeds with [a] and leaves the context untouched. *)
let return a (ctx : parser_context) = Some (a, ctx)

(** [stop] fails unconditionally without consuming anything. *)
let stop (_ : parser_context) = None

(** [fmap f p] runs [p] and, on success, transforms its result with [f].
    A failure of [p] is propagated unchanged. *)
let fmap (f: 'a -> 'b) (p: 'a parser): 'b parser = fun ctx ->
  p ctx |> Option.map (fun (value, rest) -> (f value, rest))

(** [bind a b] runs [a]; on success it passes the result to [b] and runs the
    parser [b] returns on the remaining input. A failure of [a]
    short-circuits. *)
let bind (a: 'a parser) (b: 'a -> 'b parser) = fun (ctx: parser_context) ->
  Option.bind (a ctx) (fun (value, rest) -> b value rest)

(* Infix and binding-operator spellings of [bind]. *)
let ( >>= ) p k = bind p k
let ( let* ) p k = bind p k

(** [or_parser a b] tries [a] first. If [a] fails, [b] runs on the same,
    unconsumed context (the choice backtracks). The first success wins. *)
let or_parser (a: 'a parser) (b: 'a parser): 'a parser = fun ctx ->
  match a ctx with
  | None -> b ctx
  | success -> success

let ( <|> ) a b = or_parser a b

(** Returns the next token without consuming it; fails at end of input. *)
let peek_token: Token.t parser = fun (ctx: parser_context) ->
  match Seq.uncons ctx.seq with
  | None -> None
  | Some ((tok, _), _) -> Some (tok, ctx)

(** Consumes and returns the next token; fails at end of input. *)
let next_token: Token.t parser = fun (ctx: parser_context) ->
  match Seq.uncons ctx.seq with
  | None -> None
  | Some ((tok, _), rest) -> Some (tok, { ctx with seq = rest })

(** [match_token tt] consumes the next token if its [token_type] is
    structurally equal to [tt]. It fails otherwise, including at end of
    input. *)
let match_token (tt: Token.token_type) : Token.t parser =
  let* tok = next_token in
  if tok.token_type <> tt then stop else return tok

(** [zero_or_one p] makes [p] optional, so it always succeeds. It yields
    [Some v] when [p] succeeds with [v]. It yields [None], leaving the
    context untouched, when [p] fails. *)
let zero_or_one (p: 'a parser): ('a option) parser = fun ctx ->
  p ctx
  |> Option.fold ~none:(None, ctx) ~some:(fun (value, rest) -> (Some value, rest))
  |> Option.some

(** [many p] applies [p] repeatedly until it fails and returns the results in
    order, possibly [[]]; [many p] itself never fails. The returned context
    is the one on which [p] last failed, so that failed attempt consumes
    nothing.

    Iterates with an accumulator so the stack does not grow with the number
    of matches. Caveat: if [p] can succeed without consuming input,
    [many p] never terminates. *)
let many (p: 'a parser): 'a list parser = fun ctx ->
  let rec loop acc ctx =
    match p ctx with
    | Some (x, ctx') -> loop (x :: acc) ctx'
    | None -> Some (List.rev acc, ctx)
  in
  loop [] ctx

(** [many1 p] is like [many p] but requires at least one successful match
    of [p]; it fails if the first attempt fails. *)
let many1 (p: 'a parser): 'a list parser =
  p >>= fun first -> fmap (fun rest -> first :: rest) (many p)

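(*
BNF:
  let_expr ::= let identifier = expr in expr
  fun_expr ::= fun identifier -> expr
  if_expr ::= if expr then expr else expr
  factor ::= (expr) | identifier | number
  call_expr ::= factor | call_expr factor
  level1 ::= call_expr | level1 + call_expr | level1 - call_expr
  level2 ::= level2 * level1 | level2 / level1 | level2 % level1 | level1
  level3 ::= level2 ^ level3 | level2
  expr ::= let_expr | fun_expr | if_expr | level3

Application is left-associative: f a b means (f a) b.
Lower-numbered levels bind tighter. So + and - bind tighter than *, / and %,
and ^ (right-associative) binds loosest: 1 + 2 * 3 parses as (1 + 2) * 3.
In expr, the alternatives are tried in the order listed; the first that
succeeds wins.
*)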

(** Binary operators of the expression language. The comment on each
    constructor gives the symbol [op2str] prints for it. *)
type bin_op_type = 
  | Add  (** [+] *)
  | Sub  (** [-] *)
  | Mul  (** [*] *)
  | Div  (** [/] *)
  | Mod  (** [%] *)
  | Pow  (** [^] *)

(** Maps an operator token type to its [bin_op_type]; every other token type
    maps to [None]. *)
let token2op : Token.token_type -> bin_op_type option = function
  | Token.Pow -> Some Pow
  | Token.Mul -> Some Mul
  | Token.Div -> Some Div
  | Token.Mod -> Some Mod
  | Token.Add -> Some Add
  | Token.Sub -> Some Sub
  | _ -> None

(** Source-text symbol of a binary operator. *)
let op2str : bin_op_type -> string = function
  | Pow -> "^"
  | Mul -> "*"
  | Div -> "/"
  | Mod -> "%"
  | Add -> "+"
  | Sub -> "-"

(** Unary operators.
    NOTE(review): nothing in this file constructs or consumes this type.
    [MonoOpExpr] carries a [bin_op_type] instead; confirm which was
    intended. *)
type mono_op_type = 
  | Neg

(** [Let (name, bound, body)] represents [let name = bound in body]. *)
type let_expr_tree = Let of string * expr_tree * expr_tree
(** [Fun (param, body)] represents a one-parameter function
    [fun param -> body]. *)
and fun_expr_tree = Fun of string * expr_tree
(** [If (cond, then_branch, else_branch)]. *)
and if_expr_tree = If of expr_tree * expr_tree * expr_tree
(** [Call (fn, arg)] applies [fn] to a single argument. The parser nests
    multi-argument applications to the left: [f a b] becomes
    [Call (Call (f, a), b)]. *)
and call_expr_tree = Call of expr_tree * expr_tree
(** Abstract syntax tree of an expression. *)
and expr_tree = 
  | LetExpr of let_expr_tree
  | FunExpr of fun_expr_tree
  | IfExpr of if_expr_tree
  | CallExpr of call_expr_tree
  (* operator, left operand, right operand *)
  | BinOpExpr of bin_op_type * expr_tree * expr_tree
  (* NOTE(review): carries a [bin_op_type] although [mono_op_type] exists,
     and no parser in this file builds this node; confirm intent. *)
  | MonoOpExpr of bin_op_type * expr_tree
  | Identifier of string
  | Number of int

(** [expr2str e] pretty-prints [e] as source text.

    Layout rules:
    - The bodies of [let] and [fun], and both branches of [if], start on a
      new line, indented two spaces deeper than the construct that owns
      them. [else] starts on its own line.
    - Operands of calls and binary operators are parenthesized whenever
      printing them bare would re-parse as a different tree under this
      file's grammar. For example, [Call (f, Call (g, x))] prints as
      [f (g x)], not [f g x]. *)
let expr2str (e: expr_tree): string =
  let tab n = String.make (n * 2) ' ' in
  (* Binding strength of a node's outermost construct, mirroring the
     grammar levels (factor, call_expr, level1, level2, level3, expr).
     A smaller number binds tighter. *)
  let strength = function
    | Identifier _ | Number _ -> 0
    | CallExpr _ | MonoOpExpr _ -> 1
    | BinOpExpr ((Add | Sub), _, _) -> 2
    | BinOpExpr ((Mul | Div | Mod), _, _) -> 3
    | BinOpExpr (Pow, _, _) -> 4
    | LetExpr _ | FunExpr _ | IfExpr _ -> 5
  in
  let rec aux e depth =
    match e with
    | LetExpr (Let (id, e1, e2)) ->
      (* The body is rendered at depth+1, so its first line gets that
         indentation too. *)
      Printf.sprintf "let %s = %s in\n%s%s" id (aux e1 depth)
        (tab (depth + 1)) (aux e2 (depth + 1))
    | FunExpr (Fun (id, body)) ->
      Printf.sprintf "fun %s ->\n%s%s" id
        (tab (depth + 1)) (aux body (depth + 1))
    | IfExpr (If (cond, e_then, e_else)) ->
      Printf.sprintf "if %s then\n%s%s\n%selse\n%s%s" (aux cond depth)
        (tab (depth + 1)) (aux e_then (depth + 1))
        (tab depth)
        (tab (depth + 1)) (aux e_else (depth + 1))
    | CallExpr (Call (fn, arg)) ->
      (* Application is left-associative: [fn] may itself be a call, but
         [arg] must be a factor. *)
      Printf.sprintf "%s %s" (operand 1 fn depth) (operand 0 arg depth)
    | BinOpExpr (op, lhs, rhs) ->
      let level = strength e in
      (* [^] is right-associative; every other operator is left-associative. *)
      let lhs_max, rhs_max =
        match op with
        | Pow -> (level - 1, level)
        | Add | Sub | Mul | Div | Mod -> (level, level - 1)
      in
      Printf.sprintf "%s %s %s" (operand lhs_max lhs depth) (op2str op)
        (operand rhs_max rhs depth)
    | MonoOpExpr (op, x) ->
      Printf.sprintf "%s %s" (op2str op) (operand 0 x depth)
    | Identifier id -> id
    | Number n -> string_of_int n
  (* Prints [sub] in an operand slot that only accepts constructs binding
     at least as tightly as [max_strength]; wraps it in parentheses
     otherwise. *)
  and operand max_strength sub depth =
    let s = aux sub depth in
    if strength sub <= max_strength then s else "(" ^ s ^ ")"
  in
  aux e 0

(* [peek_type] returns the type of the next token without consuming it, or
   [None] at end of input. Unlike [peek_token] it never fails, so the
   operator loops below stop cleanly when the stream runs out instead of
   failing the whole parse. *)
let peek_type : Token.token_type option parser =
  fmap
    (Option.map (fun (t : Token.t) -> t.Token.token_type))
    (zero_or_one peek_token)

(* [peek_op allowed] returns the binary operator denoted by the next token
   when that operator is in [allowed]. It returns [None] for any other token
   or at end of input, and never consumes input. *)
let peek_op (allowed : bin_op_type list) : bin_op_type option parser =
  fmap
    (fun tt ->
      match Option.bind tt token2op with
      | Some op when List.mem op allowed -> Some op
      | _ -> None)
    peek_type

(* [chain_left operand allowed] parses [operand (op operand)*] for operators
   in [allowed] and folds the results into left-associated [BinOpExpr]s. *)
let chain_left (operand : unit -> expr_tree parser) (allowed : bin_op_type list)
    : expr_tree parser =
  let rec loop lhs =
    let* op = peek_op allowed in
    match op with
    | None -> return lhs
    | Some op ->
      let* _ = next_token in
      let* rhs = operand () in
      loop (BinOpExpr (op, lhs, rhs))
  in
  let* first = operand () in
  loop first

(* Consumes an identifier token and returns its name. Fails on any other
   token or at end of input. *)
let parse_identifier : string parser =
  let* tok = next_token in
  match tok.Token.token_type with
  | Token.Identifier name -> return name
  | _ -> stop

(* Recursive-descent parser for the grammar in the BNF comment above.

   Each rule is a [unit -> _ parser] function rather than a [_ parser]
   value. The right-hand side of a [let rec] may not be an arbitrary
   application such as [bind p k], so calling the function builds the
   parser.

   Precedence, tightest first: application (parse_call_expr); then + and -
   (parse_level1); then multiplication, division and modulo (parse_level2);
   then ^ (parse_level3, right-associative). So [1 + 2 * 3] parses as
   [(1 + 2) * 3].
   NOTE(review): this is the reverse of conventional arithmetic precedence.
   It matches the BNF comment, but confirm it is intended.

   Every parser signals failure by returning [None]; none raises, and none
   writes to [errors]. *)
let rec parse_let_expr (): let_expr_tree parser =
  let* _ = match_token Token.Let in
  let* id = parse_identifier in
  let* _ = match_token Token.Equal in
  let* e1 = expr () in
  let* _ = match_token Token.In in
  let* e2 = expr () in
  return (Let (id, e1, e2))
and parse_fun_expr (): fun_expr_tree parser =
  let* _ = match_token Token.Fun in
  let* id = parse_identifier in
  let* _ = match_token Token.Arrow in
  let* body = expr () in
  return (Fun (id, body))
and parse_if_expr (): if_expr_tree parser =
  let* _ = match_token Token.If in
  let* cond = expr () in
  let* _ = match_token Token.Then in
  let* e_then = expr () in
  let* _ = match_token Token.Else in
  let* e_else = expr () in
  return (If (cond, e_then, e_else))
and parse_factor (): expr_tree parser =
  let* tok = peek_token in
  match tok.Token.token_type with
  | Token.Identifier name ->
    let* _ = next_token in
    return (Identifier name)
  | Token.Digit digits ->
    (* int_of_string raises on literals that overflow [int]; fail the parse
       instead of letting the exception escape. *)
    (match int_of_string_opt digits with
     | Some n ->
       let* _ = next_token in
       return (Number n)
     | None -> stop)
  | Token.LParen ->
    let* _ = next_token in
    let* e = expr () in
    let* _ = match_token Token.RParen in
    return e
  | _ -> stop
and parse_call_expr (): expr_tree parser =
  (* Application is left-associative: [f a b] is [Call (Call (f, a), b)].
     Keep applying while the next token can start a factor. *)
  let rec apply_args fn =
    let* next = peek_type in
    match next with
    | Some (Token.Identifier _ | Token.Digit _ | Token.LParen) ->
      let* arg = parse_factor () in
      apply_args (CallExpr (Call (fn, arg)))
    | _ -> return fn
  in
  let* fn = parse_factor () in
  apply_args fn
and parse_level1 (): expr_tree parser =
  chain_left parse_call_expr [Add; Sub]
and parse_level2 (): expr_tree parser =
  chain_left parse_level1 [Mul; Div; Mod]
and parse_level3 (): expr_tree parser =
  (* [^] is right-associative: its right operand is itself a level3. *)
  let* lhs = parse_level2 () in
  let* op = peek_op [Pow] in
  match op with
  | None -> return lhs
  | Some op ->
    let* _ = next_token in
    let* rhs = parse_level3 () in
    return (BinOpExpr (op, lhs, rhs))
and expr (): expr_tree parser =
  (* Ordered choice: the first alternative that succeeds wins. *)
  fmap (fun l -> LetExpr l) (parse_let_expr ())
  <|> fmap (fun f -> FunExpr f) (parse_fun_expr ())
  <|> fmap (fun i -> IfExpr i) (parse_if_expr ())
  <|> parse_level3 ()

(** [get_expr_tree_from_tokens tokens] drops comment tokens and parses one
    expression from the rest. Returns [None] if no expression can be parsed.
    NOTE(review): tokens left over after the expression are ignored, so
    trailing input is accepted silently. Confirm whether the lexer emits an
    end-of-input token that should be required here. *)
let get_expr_tree_from_tokens (tokens: (Token.t * Lexer.lexer_context) Seq.t): expr_tree option =
  let is_not_comment ((tok, _) : Token.t * Lexer.lexer_context) =
    match tok.Token.token_type with
    | Token.Comment _ -> false
    | _ -> true
  in
  let ctx = { seq = Seq.filter is_not_comment tokens; errors = [] } in
  Option.map fst (expr () ctx)

(* Round trip: lexing, parsing and pretty-printing a simple [let] must
   reproduce the source text exactly. *)
let%test "test get_expr_tree_from_tokens 1" =
  let source = "let x = 1 in\n  x" in
  let parsed = get_expr_tree_from_tokens (Lexer.lex_tokens_seq source) in
  Option.map expr2str parsed = Some source