(** Types manipulated by the checker. *)
type type_t =
  (* A named type applied to zero or more arguments, e.g. int and string
     (no arguments) or list and option (one argument). *)
  | TypeConstructor of string * type_t list
  (* A type variable: a placeholder for a type that is not known yet. *)
  | TypeVariable of string
  (* A function type, e.g. int -> int. The first component is the input
     type and the second the output type. The arrow is right associative,
     so (a -> b -> c) is represented as (a -> (b -> c)). *)
  | TypeArrow of type_t * type_t

(** [type_t2str ty] renders [ty] as a string. Arrows are always fully
    parenthesised; constructor arguments are printed after [of]. *)
let rec type_t2str (ty : type_t) : string =
  match ty with
  | TypeConstructor (name, []) -> name
  | TypeConstructor (name, args) ->
    name ^ " of (" ^ String.concat ", " (List.map type_t2str args) ^ ")"
  | TypeVariable x -> x
  | TypeArrow (a, b) -> "(" ^ type_t2str a ^ " -> " ^ type_t2str b ^ ")"

(** A substitution maps type-variable names to types. It is applied in a
    single pass (see {!subst}), so it is expected to be idempotent: no
    variable of its domain should occur in its range. *)
type substitution_t = (string * type_t) list

(** [occurs t id] is [true] iff the type variable [id] occurs in [t]. *)
let rec occurs (t : type_t) (id : string) : bool =
  match t with
  | TypeVariable x -> x = id
  | TypeConstructor (_, args) -> List.exists (fun arg -> occurs arg id) args
  | TypeArrow (a, b) -> occurs a id || occurs b id

(** [subst s t] replaces, in one pass, every variable of [t] bound in [s]
    by its image. Unbound variables are left untouched. *)
let rec subst (s : substitution_t) (t : type_t) : type_t =
  match t with
  | TypeVariable x -> Option.value (List.assoc_opt x s) ~default:t
  | TypeConstructor (name, args) ->
    TypeConstructor (name, List.map (subst s) args)
  | TypeArrow (a, b) -> TypeArrow (subst s a, subst s b)

(** [compose s1 s2] is the substitution "[s1] after [s2]":
    [subst (compose s1 s2) t = subst s1 (subst s2 t)].

    When substitutions are accumulated during inference, the LATER
    substitution must therefore be passed FIRST:
    [compose later earlier]. *)
let compose (s1 : substitution_t) (s2 : substitution_t) : substitution_t =
  (* Push s1 through the images of s2. *)
  let s2' = List.map (fun (x, t) -> (x, subst s1 t)) s2 in
  (* Bindings of s1 shadowed by s2 would never be reached; drop them. *)
  let s1' = List.filter (fun (x, _) -> not (List.mem_assoc x s2')) s1 in
  s1' @ s2'

(** [compose_list [s1; s2; ...; sn]] is [s1] after [s2] after ... after
    [sn], i.e. the list is given latest-first. *)
let compose_list (slist : substitution_t list) : substitution_t =
  List.fold_left compose [] slist

(** [unify t1 t2] returns a substitution [s] such that
    [subst s t1 = subst s t2].

    @raise Failure with a message starting with "Type mismatch" when the
    two types cannot be unified. This includes occurs-check failures and
    constructors applied to a different number of arguments. *)
let rec unify (t1 : type_t) (t2 : type_t) : substitution_t =
  match t1, t2 with
  | TypeVariable x, TypeVariable y when x = y -> []
  | TypeVariable x, _ when not (occurs t2 x) -> [ (x, t2) ]
  | _, TypeVariable x when not (occurs t1 x) -> [ (x, t1) ]
  | TypeConstructor (n1, args1), TypeConstructor (n2, args2)
    when n1 = n2 && List.compare_lengths args1 args2 = 0 ->
    (* Thread the substitution found so far into the remaining arguments,
       so that a variable shared between arguments is resolved
       consistently. The guard above guarantees equal lengths, hence
       fold_left2 cannot raise Invalid_argument. *)
    List.fold_left2
      (fun s a1 a2 -> compose (unify (subst s a1) (subst s a2)) s)
      [] args1 args2
  | TypeArrow (a1, b1), TypeArrow (a2, b2) ->
    let s1 = unify a1 a2 in
    let s2 = unify (subst s1 b1) (subst s1 b2) in
    (* s2 was computed after s1, so it must be applied after s1. *)
    compose s2 s1
  | _ ->
    raise
      (Failure ("Type mismatch: " ^ type_t2str t1 ^ " vs " ^ type_t2str t2))

(** [assert_type t1 t2] checks that [t1] and [t2] unify and discards the
    unifier.
    @raise Failure if they do not unify. *)
let assert_type (t1 : type_t) (t2 : type_t) : unit = ignore (unify t1 t2)

(* Test cases for the unify function *)
let%test "unify1" =
  let t1 = TypeConstructor ("int", []) in
  let t2 = TypeVariable "a" in
  let s = unify t1 t2 in
  subst s t1 = subst s t2

let%test "unify2" =
  let t1 = TypeArrow (TypeVariable "a", TypeVariable "b") in
  let t2 = TypeArrow (TypeVariable "c", TypeVariable "d") in
  let s = unify t1 t2 in
  subst s t1 = subst s t2

let%test "unify3" =
  let t1 = TypeArrow (TypeVariable "a", TypeVariable "b") in
  let t2 = TypeArrow (TypeVariable "a", TypeVariable "c") in
  let s = unify t1 t2 in
  subst s t1 = subst s t2

let%test "unify4" =
  let t1 = TypeConstructor ("list", [ TypeVariable "a" ]) in
  let t2 = TypeConstructor ("list", [ TypeVariable "b" ]) in
  let s = unify t1 t2 in
  subst s t1 = subst s t2

(* Regression: a variable bound by the argument side must be resolved by
   what is later learnt on the result side. *)
let%test "unify_arrow_chain" =
  let t1 = TypeArrow (TypeVariable "a", TypeVariable "b") in
  let t2 = TypeArrow (TypeVariable "b", TypeConstructor ("int", [])) in
  let s = unify t1 t2 in
  subst s t1 = subst s t2

(* Regression: a variable shared between constructor arguments. *)
let%test "unify_constructor_shared_variable" =
  let t1 = TypeConstructor ("pair", [ TypeVariable "a"; TypeVariable "a" ]) in
  let t2 =
    TypeConstructor ("pair", [ TypeConstructor ("int", []); TypeVariable "b" ])
  in
  let s = unify t1 t2 in
  subst s t1 = subst s t2

(* Regression: an arity mismatch is a type mismatch, not Invalid_argument. *)
let%test "unify_arity_mismatch" =
  match
    unify
      (TypeConstructor ("list", [ TypeVariable "a" ]))
      (TypeConstructor ("list", []))
  with
  | _ -> false
  | exception Failure _ -> true

(** A lexical scope: its own bindings plus an optional enclosing scope. *)
type env_t = {
  parent : env_t option;
  bindings : (string, type_t) Hashtbl.t;
}

(** [lookup env name] returns the type bound to [name] in the innermost
    scope that defines it.
    @raise Failure with "Unbound variable: <name>" if no scope binds it. *)
let rec lookup (env : env_t) (name : string) : type_t =
  match Hashtbl.find_opt env.bindings name with
  | Some ty -> ty
  | None ->
    (match env.parent with
     | Some parent -> lookup parent name
     | None -> raise (Failure ("Unbound variable: " ^ name)))

(** Checking context. [constraints] is carried along but is not read by
    any function in this file. *)
type context_t = {
  env : env_t;
  constraints : (string * type_t) list;
}

(** [type_tree2type_t tt] converts a parsed type annotation into a
    [type_t]. Identifiers become constructors without arguments. *)
let rec type_tree2type_t (tt : Parser.type_tree) =
  match tt with
  | TypeIdentifier s -> TypeConstructor (s, [])
  | TypeArrow (arg, ret) -> TypeArrow (type_tree2type_t arg, type_tree2type_t ret)

(* Source of fresh type-variable names; incremented on every request. *)
let type_variable_counter = ref 0

(** [new_type_variable ()] returns a type variable named 't<n>, where
    <n> is the current value of the global counter, and bumps the
    counter. *)
let new_type_variable () =
  let i = !type_variable_counter in
  incr type_variable_counter;
  TypeVariable ("'t" ^ string_of_int i)

(* The language has no boolean type yet: conditions and all operator
   operands are required to be int. *)
let int_type = TypeConstructor ("int", [])

(** [apply_subst_env s env] returns a copy of [env] (including all of its
    enclosing scopes) in which [s] has been applied to every bound type.
    [env] itself is not modified. The whole chain is copied, which is
    linear in the number of bindings. *)
let rec apply_subst_env (s : substitution_t) (env : env_t) : env_t =
  let bindings = Hashtbl.copy env.bindings in
  Hashtbl.filter_map_inplace (fun _ ty -> Some (subst s ty)) bindings;
  { parent = Option.map (apply_subst_env s) env.parent; bindings }

(** [apply_subst_ctx s ctx] is [ctx] with [s] applied to its environment.
    Returns [ctx] unchanged when [s] is empty. *)
let apply_subst_ctx (s : substitution_t) (ctx : context_t) : context_t =
  match s with
  | [] -> ctx
  | _ -> { ctx with env = apply_subst_env s ctx.env }

(** [check_expr ctx exp] infers the type of [exp] in [ctx].

    Returns [(ty, s)] where [s] is the substitution discovered while
    checking [exp] and [ty] already has [s] applied. What is learnt from
    one sub-expression is applied to the environment before the next
    sub-expression is checked, so conflicting uses of the same variable
    are detected.

    @raise Failure on a type mismatch or an unbound variable. *)
let rec check_expr (ctx : context_t) (exp : Parser.expr_tree)
  : type_t * substitution_t =
  match exp with
  | Parser.LetExpr l -> check_let_expr ctx l
  | Parser.FunExpr ftree -> check_fun_expr ctx ftree
  | Parser.IfExpr (Parser.If (cond_expr, then_expr, else_expr)) ->
    let cond_ty, s_cond = check_expr ctx cond_expr in
    let then_ty, s_then = check_expr (apply_subst_ctx s_cond ctx) then_expr in
    let else_ty, s_else =
      check_expr (apply_subst_ctx (compose s_then s_cond) ctx) else_expr
    in
    (* compose_list takes the latest substitution first. *)
    let s = compose_list [ s_else; s_then; s_cond ] in
    (* Currently, boolean type doesn't exist: the condition must be int. *)
    let s = compose (unify (subst s cond_ty) int_type) s in
    (* Both branches must have the same type, which is the result type. *)
    let s = compose (unify (subst s then_ty) (subst s else_ty)) s in
    (subst s then_ty, s)
  | Parser.BinOpExpr (_op, left_expr, right_expr) ->
    let left_ty, s_left = check_expr ctx left_expr in
    let right_ty, s_right =
      check_expr (apply_subst_ctx s_left ctx) right_expr
    in
    let s = compose s_right s_left in
    (* Currently only int operands are supported, and the result is int. *)
    let s = compose (unify (subst s left_ty) int_type) s in
    let s = compose (unify (subst s right_ty) int_type) s in
    (int_type, s)
  | Parser.MonoOpExpr (_op, expr) ->
    let ty, s_operand = check_expr ctx expr in
    let s = compose (unify (subst s_operand ty) int_type) s_operand in
    (int_type, s)
  | Parser.CallExpr (Parser.Call (func_expr, arg_expr)) ->
    let func_ty, s_func = check_expr ctx func_expr in
    let arg_ty, s_arg = check_expr (apply_subst_ctx s_func ctx) arg_expr in
    let s = compose s_arg s_func in
    (* Fresh variable standing for the result of the call. *)
    let ret_ty = new_type_variable () in
    (* The callee must be a function from the argument type to ret_ty. *)
    let expected_func_ty = TypeArrow (subst s arg_ty, ret_ty) in
    let s = compose (unify (subst s func_ty) expected_func_ty) s in
    (subst s ret_ty, s)
  | Parser.Identifier name -> (lookup ctx.env name, [])
  | Parser.Number _n -> (int_type, [])

(** [check_let_expr ctx l] checks [let name (: annotation)? = value in
    body]. The value is checked in [ctx] (the binding is not recursive),
    unified with the annotation when there is one, and the body is then
    checked in a child scope that binds the name. No generalisation is
    performed on the bound type. *)
and check_let_expr (ctx : context_t) (l : Parser.let_expr_tree)
  : type_t * substitution_t =
  let value_ty, s_value = check_expr ctx l.value_expr in
  let bound_ty, s_bind =
    match l.type_declare with
    | Some td ->
      let declared_ty = type_tree2type_t td in
      let s =
        compose (unify (subst s_value value_ty) declared_ty) s_value
      in
      (subst s declared_ty, s)
    | None -> (subst s_value value_ty, s_value)
  in
  (* The enclosing scopes must reflect what was learnt from the value. *)
  let outer_ctx = apply_subst_ctx s_bind ctx in
  (* A fresh table is enough: lookups fall back to the parent scope. *)
  let new_env = { parent = Some outer_ctx.env; bindings = Hashtbl.create 10 } in
  Hashtbl.replace new_env.bindings l.name bound_ty;
  let in_ty, s_in = check_expr { ctx with env = new_env } l.in_expr in
  (in_ty, compose s_in s_bind)

(** [check_fun_expr ctx ftree] checks [fun name (: annotation)? -> body].
    The parameter gets its annotated type, or a fresh type variable when
    there is no annotation; the body is checked in a child scope binding
    the parameter. *)
and check_fun_expr (ctx : context_t) (ftree : Parser.fun_expr_tree)
  : type_t * substitution_t =
  let arg_type =
    match ftree.type_declare with
    | Some td -> type_tree2type_t td
    | None -> new_type_variable () (* inferred from the body *)
  in
  (* New scope for the body, holding only the parameter. *)
  let new_env = { parent = Some ctx.env; bindings = Hashtbl.create 10 } in
  Hashtbl.replace new_env.bindings ftree.name arg_type;
  (* Checking the body may refine the parameter type, e.g. when the
     parameter is used where an int is required. *)
  let body_type, s_body =
    check_expr { ctx with env = new_env } ftree.body_expr
  in
  (TypeArrow (subst s_body arg_type, body_type), s_body)

(** [infer_type exp] infers the type of the closed expression [exp],
    starting from an empty environment.
    @raise Failure on a type mismatch or an unbound variable. *)
let infer_type (exp : Parser.expr_tree) : type_t =
  let initial_env = { parent = None; bindings = Hashtbl.create 10 } in
  let initial_ctx = { env = initial_env; constraints = [] } in
  let inferred_type, final_substitution = check_expr initial_ctx exp in
  (* Apply the final substitution to the whole type. *)
  subst final_substitution inferred_type

let%test "test_infer_type1" =
  let exp =
    Parser.BinOpExpr (Parser.Add, Parser.Number 1, Parser.Number 2)
  in
  infer_type exp = TypeConstructor ("int", [])

let%test "infer_type_fun_add" =
  let input_str = "fun x -> x + 1" in
  let tokens = Lexer.lex_tokens_seq input_str in
  match Parser.get_expr_tree_from_tokens tokens with
  | Some expr ->
    let expected =
      TypeArrow (TypeConstructor ("int", []), TypeConstructor ("int", []))
    in
    infer_type expr = expected
  | None ->
    Printf.printf "Failed to parse expression: %s\n" input_str;
    false

let%test "infer_type_let_fun" =
  let input_str = "let f = fun x -> x + 1 in f 10" in
  let tokens = Lexer.lex_tokens_seq input_str in
  match Parser.get_expr_tree_from_tokens tokens with
  | Some expr -> infer_type expr = TypeConstructor ("int", [])
  | None ->
    Printf.printf "Failed to parse expression: %s\n" input_str;
    false

let%test "infer_type_if" =
  let input_str = "if 1 then 2 else 3" in
  let tokens = Lexer.lex_tokens_seq input_str in
  match Parser.get_expr_tree_from_tokens tokens with
  | Some expr -> infer_type expr = TypeConstructor ("int", [])
  | None ->
    Printf.printf "Failed to parse expression: %s\n" input_str;
    false