From d801a1b34f946cd8be9d6a405f78b13cc8aaf07a Mon Sep 17 00:00:00 2001 From: monoid Date: Sat, 3 May 2025 19:42:58 +0900 Subject: [PATCH] feat: add type checking and evaluation with type inference for expressions --- bin/main.ml | 7 +- lib/eval.ml | 16 ++ lib/typecheck.ml | 468 +++++++++++++++++++++++++++-------------------- 3 files changed, 290 insertions(+), 201 deletions(-) diff --git a/bin/main.ml b/bin/main.ml index c15293e..8c0c806 100644 --- a/bin/main.ml +++ b/bin/main.ml @@ -4,10 +4,9 @@ open Calc;; let main () = let input = Sys.argv.(1) in let _ = Printf.printf "input: %s\n" input in - let result = Eval.eval_str input in - match result with - | Eval.Int n -> Printf.printf "%d\n" n - | _ -> failwith "Type error" + let (result, ty) = Eval.eval_str_with_typecheck input in + Printf.printf "type: %s\n" (Typecheck.type_t2str ty); + Printf.printf "result: %s\n" (result |> Eval.string_of_value_type); ;; main ();; \ No newline at end of file diff --git a/lib/eval.ml b/lib/eval.ml index 8ed2751..25061a1 100644 --- a/lib/eval.ml +++ b/lib/eval.ml @@ -13,6 +13,11 @@ and function_type = { scope: scope; } +let string_of_value_type (v: value_type): string = + match v with + | Int n -> string_of_int n + | Fun f -> Printf.sprintf " %s >" f.argname (f.body |> Parser.expr2str) + let rec eval_expr (scope: scope) (expr: Parser.expr_tree): value_type = match expr with | Parser.LetExpr (l) -> @@ -104,6 +109,17 @@ let eval_str (str: string): value_type = | Some e -> eval_expr { parent = None; bindings = VariableBindingMap.empty } e | None -> failwith "Parse error" +let eval_str_with_typecheck (str: string): value_type * Typecheck.type_t = + let tokens = Lexer.lex_tokens_seq str in + let expr = Parser.get_expr_tree_from_tokens tokens in + match expr with + | Some e -> + let ty = Typecheck.infer_type e in + let v = eval_expr { parent = None; bindings = VariableBindingMap.empty } e in + (v, ty) + | None -> failwith "Parse error" + +(* Test cases *) let%test "test eval_str 1" = let result = eval_str "let x = 1 in x" in diff --git a/lib/typecheck.ml b/lib/typecheck.ml index f45333b..6bbdb37 100644 --- a/lib/typecheck.ml +++ b/lib/typecheck.ml @@ -1,215 +1,289 @@ -type type_v = - | Int - | Fun of { - arg: type_v; - ret: type_v; - } - | Generic of string - | Universal - | Nothing +type type_t = +(* e.g. list, option, etc. *) + | TypeConstructor of string * type_t list +(* e.g. int, string, etc. *) + | TypeVariable of string +(* e.g. int -> int, string -> string, etc. + This is a function type, where the first type is the input and the second type is the output. + The arrow can be right associative, so (a -> b -> c) is equivalent to (a -> (b -> c)). +*) + | TypeArrow of type_t * type_t -type type_scope = { - parent: type_scope option; - bindings: (string, type_v) Hashtbl.t; - generics_count: int ref; -} +let type_t2str (ty: type_t) = + let rec aux ty = + match ty with + | TypeConstructor (name, args) -> + name ^ if List.is_empty args then "" else + " of (" ^ String.concat ", " (List.map aux args) ^ ")" + | TypeVariable x -> x + | TypeArrow (a, b) -> "(" ^ aux a ^ " -> " ^ aux b ^ ")" + in + aux ty -let make_type_scope (parent: type_scope): type_scope = { - parent = Some (parent); - bindings = Hashtbl.create 10; - generics_count = parent.generics_count; -} +type substitution_t = (string * type_t) list -let make_top_type_scope (): type_scope = { - parent = None; - bindings = Hashtbl.create 10; - generics_count = ref 0; -} - -let rec typetree2type_v (t: Parser.type_tree): type_v = +(* check if a variable occurs in a type_t *) +let rec occurs (t : type_t) (id: string) : bool = match t with - | Parser.TypeIdentifier (x) -> if x = "int" then Int else - failwith "not implemented (type alias is not supported yet)" - | Parser.TypeArrow (arg, ret) -> Fun { arg = typetree2type_v arg; ret = typetree2type_v ret } + | TypeVariable x -> x = id + | TypeConstructor (_, lst) -> List.exists (fun x -> (occurs x id)) lst + | TypeArrow (a,b) -> (occurs a id) || (occurs b id) -let rec type_v2str (t: type_v): string = +let rec subst (s: substitution_t) (t: type_t) : type_t = match t with - | Int -> "int" - | Fun { arg = arg; ret = ret } -> Printf.sprintf "(%s -> %s)" (type_v2str arg) (type_v2str ret) - | Generic s -> "'" ^ s - | Universal -> "universal" - | Nothing -> "nothing" + | TypeVariable x -> (try List.assoc x s with Not_found -> t) + | TypeConstructor (name, args) -> TypeConstructor (name, List.map (subst s) args) + | TypeArrow (a, b) -> TypeArrow (subst s a, subst s b) -(* meet *) -let rec intersect_type_v (a: type_v) (b: type_v): type_v = - match a, b with - | Universal, _ -> b - | _, Universal -> a - | Int, Int -> Int - | Fun { arg = arg1; ret = ret1 }, Fun { arg = arg2; ret = ret2 } -> - (* contravariance *) - let arg = intersect_type_v arg1 arg2 in - let ret = intersect_type_v ret1 ret2 in - Fun { arg = arg; ret = ret } - (* // TODO: fix it *) - | Generic s1, Generic s2 when s1 = s2 -> Generic s1 - | Generic _, _ -> b - | _ -> Nothing -(* join *) -(* and union_type_v (a: type_v) (b: type_v): type_v = - match a, b with - | Universal, _ -> Universal - | _, Universal -> Universal - | Int, Int -> Int - | Fun { arg = arg1; ret = ret1 }, Fun { arg = arg2; ret = ret2 } -> +let compose (s1: substitution_t) (s2: substitution_t) : substitution_t = + (* apply s1 to s2 *) + let s2' = List.map (fun (x, t) -> (x, subst s1 t)) s2 in + (* remove duplicates *) + let s1' = List.filter (fun (x, _) -> not (List.mem_assoc x s2')) s1 in + s1' @ s2' - | Generic s1, Generic s2 when s1 = s2 -> Generic s1 - | Generic _, _ -> b - | _ -> Nothing *) +let compose_list (slist : substitution_t list) = + List.fold_left compose [] slist -(* it assumes that there is already variable binding. *) -let find_type_v_opt (scope: type_scope) (name: string): type_v option = - let rec find_binding scope = - match scope with - | None -> None - | Some s -> - match Hashtbl.find_opt s.bindings name with - | Some v -> Some v - | None -> find_binding s.parent in - find_binding (Some scope) +let rec unify (t1: type_t) (t2: type_t) : substitution_t = + match t1, t2 with + | TypeVariable x, _ when not (occurs t2 x) -> [(x, t2)] + | _, TypeVariable x when not (occurs t1 x) -> [(x, t1)] + | TypeVariable x, TypeVariable y when x = y -> [] + | TypeConstructor (n1, args1), TypeConstructor (n2, args2) when n1 = n2 -> + List.fold_left2 (fun s a1 a2 -> compose s (unify a1 a2)) [] args1 args2 + | TypeArrow (a1, b1), TypeArrow (a2, b2) -> + let s1 = unify a1 a2 in + let s2 = unify (subst s1 b1) (subst s1 b2) in + compose s1 s2 + | _ -> raise (Failure ("Type mismatch: " ^ type_t2str t1 ^ " vs " ^ type_t2str t2)) -(* it assumes that there is already variable binding. *) -let assert_and_get_type_v (scope: type_scope) (name: string) (expected: type_v) = - let rec assert_binding scope = - match scope with - | None -> failwith "Unbound variable" - | Some s -> - match Hashtbl.find_opt s.bindings name with - | Some v -> - let subtype = intersect_type_v v expected in - if subtype = Nothing then failwith "Type error" - else Hashtbl.replace s.bindings name subtype; - subtype - | None -> assert_binding s.parent in - assert_binding (Some scope) - -let gen_generic_free_name (scope: type_scope): string = - let generics_count = !(scope.generics_count) in - let name = Printf.sprintf "%d" generics_count in - scope.generics_count := generics_count + 1; - name - -let replace_generic_with (t: type_v) (from: string) (to_: type_v): type_v = - let rec replace t = - match t with - | Int -> Int - | Fun { arg = arg; ret = ret } -> Fun { arg = replace arg; ret = replace ret } - | Generic s when s = from -> to_ - | Generic s -> Generic s - | Universal -> Universal - | Nothing -> Nothing in - replace t - -let rec typecheck_expr (scope: type_scope) (expr: Parser.expr_tree) (required_type: type_v): type_v = - let actual_type = match expr with - | Parser.LetExpr (l) -> - typecheck_let_expr scope required_type l - | Parser.FunExpr (ftree) -> - typecheck_fun_expr scope required_type ftree - | Parser.IfExpr (Parser.If (cond_expr, then_expr, else_expr)) -> - typecheck_if_expr scope required_type cond_expr then_expr else_expr - | Parser.BinOpExpr (op, left_expr, right_expr) -> - typecheck_bin_op_expr scope required_type op left_expr right_expr - | Parser.MonoOpExpr (_op, _expr) -> - failwith "Not implemented" - | Parser.CallExpr (Parser.Call (func_expr, arg_expr)) -> - typecheck_call_expr scope required_type func_expr arg_expr - | Parser.Identifier(name) -> assert_and_get_type_v scope name required_type - | Parser.Number(_n) -> Int - in - let subtype = intersect_type_v required_type actual_type in - if subtype = Nothing then - failwith (Printf.sprintf "Type error: expect %s but actual %s" - (type_v2str required_type) (type_v2str actual_type) - ) - else subtype -and typecheck_let_expr (scope: type_scope) (required_type: type_v) ({ - name = name; - value_expr = value_expr; - in_expr = in_expr; - type_declare = type_decl; -}: Parser.let_expr_tree): type_v = - let value_reqired_type = type_decl |> Option.map typetree2type_v |> Option.value ~default: Universal in - let value_type = typecheck_expr scope value_expr value_reqired_type in - let new_scope = make_type_scope scope in - Hashtbl.add new_scope.bindings name value_type; - typecheck_expr new_scope in_expr required_type -and typecheck_fun_expr (scope: type_scope) (_required_type: type_v) ({ - name = argname; - body_expr = body_expr; - type_declare = type_decl; -}: Parser.fun_expr_tree): type_v = - let default_type = Generic (gen_generic_free_name scope) in - let arg_type = type_decl |> Option.map typetree2type_v |> Option.value ~default: default_type in - let new_scope = make_type_scope scope in - Hashtbl.add new_scope.bindings argname arg_type; - (* unreachable because *) - let ret_type = typecheck_expr new_scope body_expr Universal in - let arg_type = Hashtbl.find new_scope.bindings argname in - Printf.printf "arg: %s, ret: %s\n" (type_v2str arg_type) (type_v2str ret_type); +let assert_type (t1: type_t) (t2: type_t) : unit = + ignore @@ unify t1 t2 - Fun { arg = arg_type; ret = ret_type } -and typecheck_if_expr (scope: type_scope) (required_type: type_v) - (cond_expr: Parser.expr_tree) (then_expr: Parser.expr_tree) (else_expr: Parser.expr_tree): type_v = - let _ = typecheck_expr scope cond_expr Int in - let then_type = typecheck_expr scope then_expr required_type in - let else_type = typecheck_expr scope else_expr required_type in - intersect_type_v then_type else_type -and typecheck_bin_op_expr (scope: type_scope) (_required_type: type_v) - (_op: Parser.bin_op_type) (left_expr: Parser.expr_tree) (right_expr: Parser.expr_tree): type_v = - (* default int *) - let _ = typecheck_expr scope left_expr Int in - let _ = typecheck_expr scope right_expr Int in - Int -and typecheck_call_expr (scope: type_scope) (_required_type: type_v) - (func_expr: Parser.expr_tree) (arg_expr: Parser.expr_tree): type_v = - let func_type = typecheck_expr scope func_expr Universal in - Printf.printf "func_type: %s\n" (type_v2str func_type); + (* Test cases for the unify function *) +let%test "unify1" = + let t1 = TypeConstructor ("int", []) in + let t2 = TypeVariable "a" in + let s = unify t1 t2 in + (* Printf.printf "%s\n" (type_t2str (subst s t1)); *) + subst s t1 = subst s t2 - match func_type with - | Fun { arg = arg_type; ret = ret_type } -> - let mono_arg_type = typecheck_expr scope arg_expr arg_type in - Printf.printf "arg_type: %s\n" (type_v2str mono_arg_type); - begin match arg_type with - | Generic s -> - (* instance *) - let new_ret_type = replace_generic_with ret_type s mono_arg_type in - Printf.printf "new_ret_type: %s\n" (type_v2str new_ret_type); - new_ret_type - | _ -> ret_type - end - | _ -> failwith "Type error" +let%test "unify2" = + let t1 = TypeArrow (TypeVariable "a", TypeVariable "b") in + let t2 = TypeArrow (TypeVariable "c", TypeVariable "d") in + let s = unify t1 t2 in + (* Printf.printf "%s\n" (s |> List.map (fun (x, t) -> x ^ " => " ^ type_t2str t) |> String.concat ", "); + Printf.printf "%s\n" (type_t2str (subst s t1)); *) + subst s t1 = subst s t2 -let typecheck (expr: Parser.expr_tree): type_v = - typecheck_expr (make_top_type_scope()) expr Universal +let%test "unify3" = + let t1 = TypeArrow (TypeVariable "a", TypeVariable "b") in + let t2 = TypeArrow (TypeVariable "a", TypeVariable "c") in + let s = unify t1 t2 in + (* Printf.printf "%s\n" (s |> List.map (fun (x, t) -> x ^ " => " ^ type_t2str t) |> String.concat ", "); + Printf.printf "%s\n" (type_t2str (subst s t1)); *) + subst s t1 = subst s t2 -let typecheck_result (expr: Parser.expr_tree): (type_v, exn) result = - try - let t = typecheck expr in - Result.Ok (t) - with e -> Result.Error e +let%test "unify4" = + let t1 = TypeConstructor ("list", [TypeVariable "a"]) in + let t2 = TypeConstructor ("list", [TypeVariable "b"]) in + let s = unify t1 t2 in + (* Printf.printf "%s\n" (s |> List.map (fun (x, t) -> x ^ " => " ^ type_t2str t) |> String.concat ", "); + Printf.printf "%s\n" (type_t2str (subst s t1)); *) + subst s t1 = subst s t2 -let test_typecheck (content:string) = - let tokens = Lexer.lex_tokens_seq content in - let expr = Parser.get_expr_tree_from_tokens tokens in - match expr with - | Some e -> typecheck_result e - | None -> Result.Error (Failure "parse error") +type env_t = { + parent : env_t option; + bindings : (string, type_t) Hashtbl.t; +} -let%test "typecheck 1" = - let expr = "let x = fun y -> y in x 1" in - match test_typecheck expr with - | Result.Ok (t) -> Printf.printf "%s\n" (type_v2str t); t = Int - | Result.Error _ -> Printf.printf "error\n"; false \ No newline at end of file +let rec lookup (env: env_t) (name: string) = + match Hashtbl.find_opt env.bindings name with + | Some ty -> ty + | None -> match env.parent with + | Some parent -> lookup parent name + | None -> raise (Failure ("Unbound variable: " ^ name)) + +type context_t = { + env : env_t; + constraints : (string * type_t) list; +} + +let rec type_tree2type_t (tt: Parser.type_tree) = + match tt with + | TypeIdentifier(s) -> TypeConstructor(s, []) + | TypeArrow(arg, ret) -> TypeArrow(type_tree2type_t(arg), type_tree2type_t(ret)) + +let type_variable_counter = ref 0 +let new_type_variable () = + let i = !type_variable_counter in + incr type_variable_counter; + TypeVariable ("'t" ^ string_of_int i) + +let rec check_expr (ctx: context_t) (exp: Parser.expr_tree) : (type_t * substitution_t) = + match exp with + | Parser.LetExpr (l) -> + let ty = check_let_expr ctx l in + ty + | Parser.FunExpr (ftree) -> + let ty = check_fun_expr ctx ftree in + ty + | Parser.IfExpr (Parser.If (cond_expr, then_expr, else_expr)) -> + let (cond_ty, cond_s) = check_expr ctx cond_expr in + let (then_ty, then_s) = check_expr ctx then_expr in + let (else_ty, else_s) = check_expr ctx else_expr in + let s1 = [cond_s; then_s; else_s] |> compose_list in + (* currently, boolean type doesn't exist. *) + let s2 = unify cond_ty (TypeConstructor ("int", [])) in + let s3 = compose s1 s2 in + (* unify the types of then and else branches *) + let s4 = (unify (subst s3 then_ty) (subst s3 else_ty)) in + let s5 = compose s4 s3 in + (* return the type of then and else branches *) + (subst s5 then_ty, s5) + | Parser.BinOpExpr (_op, left_expr, right_expr) -> + let (left_ty, ls) = check_expr ctx left_expr in + let (right_ty, rs) = check_expr ctx right_expr in + (* unify the types of the left and right expressions *) + (* but currently, only int is supported. *) + let s = compose ls rs in + let s2 = unify (subst s left_ty) (TypeConstructor ("int", [])) in + let s3 = compose s s2 in + (* unify the types of the left and right expressions *) + let s4 = unify (subst s3 right_ty) (TypeConstructor ("int", [])) in + let s5 = compose s3 s4 in + (* return int type *) + (TypeConstructor ("int", []), s5) + | Parser.MonoOpExpr (_op, expr) -> + let (ty, s1) = check_expr ctx expr in + let s2 = unify (subst s1 ty) (TypeConstructor ("int", [])) in + let s3 = compose s1 s2 in + (* return int type *) + (TypeConstructor ("int", []), s3) + | Parser.CallExpr (Parser.Call (func_expr, arg_expr)) -> + let (func_ty, func_s) = check_expr ctx func_expr in + (* s1을 context에 적용하는 로직 추가 필요 *) + let (arg_ty, arg_s) = check_expr ctx arg_expr in + let s1 = compose func_s arg_s in + let ret_ty_var = new_type_variable () in (* 반환 타입을 위한 새 타입 변수 *) + let expected_func_ty = TypeArrow (subst s1 arg_ty, ret_ty_var) in + let s2 = unify (subst s1 func_ty) expected_func_ty in (* 실제 함수 타입과 기대 타입 통일 *) + let s3 = compose s1 s2 in + (subst s3 ret_ty_var, s3) (* 추론된 반환 타입과 최종 substitution 반환 *) + | Parser.Identifier(name) -> + let ty = lookup ctx.env name in + (ty, []) + | Parser.Number(_n) -> (TypeConstructor ("int", []), []) + +and check_let_expr (ctx: context_t) (l: Parser.let_expr_tree) : type_t * substitution_t = + let (vtt, s1) = check_expr ctx l.value_expr in + let (vt, s2) = match l.type_declare with + | Some(td) -> + let tt = type_tree2type_t td in + let s_unify = unify (subst s1 vtt) tt in (* s1 적용 후 통일 *) + let s_compose = compose s1 s_unify in + (subst s_compose tt, s_compose) + | None -> (subst s1 vtt, s1) (* s1 적용 *) + in + (* s2는 값 검사와 타입 선언 통일로부터 얻은 최종 substitution *) + let final_vt = vt in + + (* s2를 context에 적용하는 로직 추가 필요 *) + let new_env = { + parent = Some ctx.env; (* 수정된 context의 env 사용 필요 *) + bindings = Hashtbl.copy ctx.env.bindings; (* 수정된 context의 bindings 사용 필요 *) + } in + Hashtbl.replace new_env.bindings l.name final_vt; + let new_ctx = { ctx with env = new_env } in (* 수정된 context 사용 필요 *) + + let (in_type, s3) = check_expr new_ctx l.in_expr in + let s4 = compose s2 s3 in + (in_type, s4) + +and check_fun_expr (ctx: context_t) (ftree: Parser.fun_expr_tree) : type_t * substitution_t = + (* Create a new environment for the function body *) + let new_env = { + parent = Some ctx.env; + bindings = Hashtbl.create 10; (* Start with an empty table for the new scope *) + } in + + (* Determine the argument type *) + let arg_type = match ftree.type_declare with + | Some(td) -> type_tree2type_t td + | None -> new_type_variable () (* Infer the type if not declared *) + in + + (* Add the argument binding to the new environment *) + Hashtbl.add new_env.bindings ftree.name arg_type; + let new_ctx = { ctx with env = new_env } in + + (* Check the type of the function body in the new context *) + (* 이 과정에서 substitution이 발생할 수 있음 (예: x가 int를 요구하는 곳에 사용될 때) *) + let (body_type, s_body) = check_expr new_ctx ftree.body_expr in + + (* 함수 본문 검사에서 얻은 substitution을 인자 타입에 적용 *) + let final_arg_type = subst s_body arg_type in + + (* 최종 함수 타입과 발생한 substitution 반환 *) + (TypeArrow (final_arg_type, body_type), s_body) + +let infer_type (exp: Parser.expr_tree) : type_t = + let initial_env = { parent = None; bindings = Hashtbl.create 10 } in + let initial_ctx = { env = initial_env; constraints = [] } in + let (inferred_type, final_substitution) = check_expr initial_ctx exp in + (* 최종적으로 얻은 substitution을 전체 타입에 적용 *) + subst final_substitution inferred_type + +let%test "test_infer_type1" = + let exp = Parser.BinOpExpr (Parser.Add, Parser.Number(1), Parser.Number(2)) in + let inferred_type = infer_type exp in + inferred_type = TypeConstructor ("int", []) + + + let%test "infer_type_fun_add" = + let input_str = "fun x -> x + 1" in + let tokens = Lexer.lex_tokens_seq input_str in + let expr_opt = Parser.get_expr_tree_from_tokens tokens in + match expr_opt with + | Some expr -> + let inferred = infer_type expr in + let expected = TypeArrow (TypeConstructor ("int", []), TypeConstructor ("int", [])) in + (* Printf.printf "Inferred: %s\n" (type_t2str inferred); + Printf.printf "Expected: %s\n" (type_t2str expected); *) + inferred = expected + | None -> + Printf.printf "Failed to parse expression: %s\n" input_str; + false + +let%test "infer_type_let_fun" = + let input_str = "let f = fun x -> x + 1 in f 10" in + let tokens = Lexer.lex_tokens_seq input_str in + let expr_opt = Parser.get_expr_tree_from_tokens tokens in + match expr_opt with + | Some expr -> + let inferred = infer_type expr in + let expected = TypeConstructor ("int", []) in + (* Printf.printf "Inferred: %s\n" (type_t2str inferred); + Printf.printf "Expected: %s\n" (type_t2str expected); *) + inferred = expected + | None -> + Printf.printf "Failed to parse expression: %s\n" input_str; + false + +let%test "infer_type_if" = + let input_str = "if 1 then 2 else 3" in + let tokens = Lexer.lex_tokens_seq input_str in + let expr_opt = Parser.get_expr_tree_from_tokens tokens in + match expr_opt with + | Some expr -> + let inferred = infer_type expr in + let expected = TypeConstructor ("int", []) in + (* Printf.printf "Inferred: %s\n" (type_t2str inferred); + Printf.printf "Expected: %s\n" (type_t2str expected); *) + inferred = expected + | None -> + Printf.printf "Failed to parse expression: %s\n" input_str; + false \ No newline at end of file