289 lines
No EOL
11 KiB
OCaml
289 lines
No EOL
11 KiB
OCaml
|
|
(** Internal representation of the types manipulated by the checker. *)
type type_t =
  | TypeConstructor of string * type_t list
      (** A named type applied to zero or more argument types. A base type
          such as [int] is a constructor with an empty argument list, while
          a parameterised type such as [list] or [option] carries its
          parameters in the list. *)
  | TypeVariable of string
      (** A placeholder for a type that is not known yet, identified by its
          name. Unification replaces variables with concrete types. *)
  | TypeArrow of type_t * type_t
      (** A function type: the first component is the input type and the
          second one is the output type. Arrows associate to the right, so
          [a -> b -> c] is represented as [a -> (b -> c)]. *)
|
|
|
|
(** [type_t2str ty] renders [ty] as a human-readable string.

    - a variable is printed as its name;
    - an arrow is printed fully parenthesised, as [(input -> output)];
    - a constructor without arguments is printed as its name, and one with
      arguments as [name of (arg1, arg2, ...)]. *)
let rec type_t2str (ty: type_t) =
  match ty with
  | TypeVariable v -> v
  | TypeArrow (input, output) ->
      "(" ^ type_t2str input ^ " -> " ^ type_t2str output ^ ")"
  | TypeConstructor (name, []) -> name
  | TypeConstructor (name, params) ->
      let rendered = params |> List.map type_t2str |> String.concat ", " in
      name ^ " of (" ^ rendered ^ ")"
|
|
|
|
(** A substitution: an association list from type-variable names to the
    types that replace them. *)
type substitution_t = (string * type_t) list
|
|
|
|
(** [occurs t id] is [true] when the type variable named [id] appears
    anywhere inside [t], and [false] otherwise. *)
let rec occurs (t : type_t) (id: string) : bool =
  match t with
  | TypeVariable v -> v = id
  | TypeArrow (input, output) -> occurs input id || occurs output id
  | TypeConstructor (_, params) -> List.exists (fun p -> occurs p id) params
|
|
|
|
(** [subst s t] replaces every type variable of [t] that is bound in [s] by
    the type [s] associates with it. Variables that [s] does not bind are
    left unchanged. The replacement is a single pass: the types taken from
    [s] are inserted as they are and are not substituted again. *)
let rec subst (s: substitution_t) (t: type_t) : type_t =
  match t with
  | TypeVariable v ->
      (match List.assoc_opt v s with
       | Some replacement -> replacement
       | None -> t)
  | TypeArrow (input, output) -> TypeArrow (subst s input, subst s output)
  | TypeConstructor (name, params) ->
      TypeConstructor (name, List.map (fun p -> subst s p) params)
|
|
|
|
(** [compose s1 s2] builds the substitution that behaves like applying [s2]
    first and [s1] afterwards.

    Every binding of [s2] is kept with [s1] applied to its type. A binding
    of [s1] is kept only when its variable is not already bound by [s2], so
    the result has at most one binding per variable coming from each side. *)
let compose (s1: substitution_t) (s2: substitution_t) : substitution_t =
  (* Push s1 through the right-hand sides of s2. *)
  let rewritten = List.map (fun (v, ty) -> (v, subst s1 ty)) s2 in
  (* Bindings of s1 whose variable is already covered by s2 are dropped. *)
  let is_shadowed (v, _) = List.mem_assoc v rewritten in
  let kept = List.filter (fun binding -> not (is_shadowed binding)) s1 in
  kept @ rewritten
|
|
|
|
(** [compose_list slist] folds [compose] over [slist] from left to right,
    starting from the empty substitution. An empty list yields the empty
    substitution. *)
let compose_list (slist : substitution_t list) =
  let step acc s = compose acc s in
  List.fold_left step [] slist
|
|
|
|
|
|
(** [unify t1 t2] computes a substitution [s] such that
    [subst s t1 = subst s t2].

    - A variable unifies with any type that does not contain it (occurs
      check); a variable unifies with itself through the empty substitution.
    - Two constructors unify when they have the same name and the same
      number of arguments, and their arguments unify pairwise.
    - Two arrows unify when their inputs and their outputs unify.

    Sub-problems are solved from left to right. The substitution found so
    far is applied to the remaining sub-terms before they are unified, and
    each newer substitution is composed on top of the older one
    ([compose newer older]) so that earlier bindings are rewritten with
    what was learnt later. This keeps the result valid for the single-pass
    [subst].

    @raise Failure with a message starting with ["Type mismatch: "] when
    the two types cannot be unified, including constructors applied to a
    different number of arguments. *)
let rec unify (t1: type_t) (t2: type_t) : substitution_t =
  match t1, t2 with
  | TypeVariable x, TypeVariable y when x = y -> []
  | TypeVariable x, _ when not (occurs t2 x) -> [(x, t2)]
  | _, TypeVariable x when not (occurs t1 x) -> [(x, t1)]
  | TypeConstructor (n1, args1), TypeConstructor (n2, args2)
    when n1 = n2 && List.compare_lengths args1 args2 = 0 ->
      (* The length guard above makes fold_left2 safe: an arity mismatch
         falls through to the generic type-mismatch error instead. *)
      List.fold_left2
        (fun acc a1 a2 ->
          let s = unify (subst acc a1) (subst acc a2) in
          compose s acc)
        [] args1 args2
  | TypeArrow (a1, b1), TypeArrow (a2, b2) ->
      let s1 = unify a1 a2 in
      let s2 = unify (subst s1 b1) (subst s1 b2) in
      (* s2 was computed after s1, so it must be applied on top of s1. *)
      compose s2 s1
  | _ -> raise (Failure ("Type mismatch: " ^ type_t2str t1 ^ " vs " ^ type_t2str t2))
|
|
|
|
(** [assert_type t1 t2] checks that [t1] and [t2] can be unified and
    discards the resulting substitution. Any exception raised by [unify]
    propagates to the caller. *)
let assert_type (t1: type_t) (t2: type_t) : unit =
  let _ : substitution_t = unify t1 t2 in
  ()
|
|
|
|
(* Tests for [unify]: applying the returned substitution to both inputs
   must make them structurally equal. *)

(* A base type against a type variable. *)
let%test "unify1" =
  let lhs = TypeConstructor ("int", []) in
  let rhs = TypeVariable "a" in
  let sigma = unify lhs rhs in
  (* For debugging, print [type_t2str (subst sigma lhs)]. *)
  subst sigma lhs = subst sigma rhs
|
|
|
|
(* Two arrows built from four distinct variables. *)
let%test "unify2" =
  let arrow input output = TypeArrow (TypeVariable input, TypeVariable output) in
  let lhs = arrow "a" "b" in
  let rhs = arrow "c" "d" in
  let sigma = unify lhs rhs in
  (* For debugging, print each binding of [sigma] with [type_t2str], then
     [type_t2str (subst sigma lhs)]. *)
  subst sigma lhs = subst sigma rhs
|
|
|
|
(* Two arrows sharing their input variable. *)
let%test "unify3" =
  let arrow input output = TypeArrow (TypeVariable input, TypeVariable output) in
  let lhs = arrow "a" "b" in
  let rhs = arrow "a" "c" in
  let sigma = unify lhs rhs in
  (* For debugging, print each binding of [sigma] with [type_t2str], then
     [type_t2str (subst sigma lhs)]. *)
  subst sigma lhs = subst sigma rhs
|
|
|
|
(* The same constructor applied to two different variables. *)
let%test "unify4" =
  let list_of v = TypeConstructor ("list", [TypeVariable v]) in
  let lhs = list_of "a" in
  let rhs = list_of "b" in
  let sigma = unify lhs rhs in
  (* For debugging, print each binding of [sigma] with [type_t2str], then
     [type_t2str (subst sigma lhs)]. *)
  subst sigma lhs = subst sigma rhs
|
|
|
|
(** A lexical scope: the names bound directly in this scope together with a
    link to the enclosing scope, if there is one. *)
type env_t = {
  parent : env_t option;
      (** The enclosing scope; [None] for the outermost scope. *)
  bindings : (string, type_t) Hashtbl.t;
      (** Types of the names bound directly in this scope. *)
}
|
|
|
|
(** [lookup env name] returns the type bound to [name], searching [env]
    first and then each enclosing scope in turn, innermost first.

    @raise Failure with the message ["Unbound variable: " ^ name] when no
    scope in the chain binds [name]. *)
let rec lookup (env: env_t) (name: string) =
  match Hashtbl.find_opt env.bindings name, env.parent with
  | Some ty, _ -> ty
  | None, Some outer -> lookup outer name
  | None, None -> failwith ("Unbound variable: " ^ name)
|
|
|
|
(** State handed down while an expression is being checked. *)
type context_t = {
  env : env_t;
      (** Scope used to resolve identifiers. *)
  constraints : (string * type_t) list;
      (** NOTE(review): in the code visible here this field is only ever
          initialised to the empty list and never read; its intended use
          should be confirmed. *)
}
|
|
|
|
(** [type_tree2type_t tt] converts a type annotation produced by the parser
    into the checker's [type_t]. An identifier becomes a constructor with
    no arguments, and an arrow is converted component by component. *)
let rec type_tree2type_t (tt: Parser.type_tree) =
  match tt with
  | TypeIdentifier name -> TypeConstructor (name, [])
  | TypeArrow (arg, ret) ->
      let arg_ty = type_tree2type_t arg in
      let ret_ty = type_tree2type_t ret in
      TypeArrow (arg_ty, ret_ty)
|
|
|
|
(** Number that the next fresh type variable will receive. *)
let type_variable_counter = ref 0

(** [new_type_variable ()] returns a type variable that has not been handed
    out before. Names are built from the prefix ['t] followed by the
    current counter value, and the counter is advanced on every call. *)
let new_type_variable () =
  let id = !type_variable_counter in
  type_variable_counter := id + 1;
  TypeVariable ("'t" ^ string_of_int id)
|
|
|
|
(** [check_expr ctx exp] infers the type of [exp] in context [ctx].

    Returns the inferred type together with the substitution discovered
    while checking. The returned type already has that substitution
    applied.

    Substitutions are always combined as [compose newer older]: a
    substitution obtained by unifying types that already had [older]
    applied must be layered on top of it, otherwise bindings of [older]
    keep referring to variables that [newer] has resolved.

    NOTE(review): sub-expressions are checked independently in the same
    context, and their substitutions are merged with [compose], which keeps
    a single binding per variable. Conflicting requirements coming from two
    sibling sub-expressions are therefore not reported here. Applying the
    substitution to the context between sub-expressions is still to do.

    @raise Failure when unification fails or an identifier is unbound. *)
let rec check_expr (ctx: context_t) (exp: Parser.expr_tree) : (type_t * substitution_t) =
  let int_ty = TypeConstructor ("int", []) in
  match exp with
  | Parser.LetExpr (l) -> check_let_expr ctx l
  | Parser.FunExpr (ftree) -> check_fun_expr ctx ftree
  | Parser.IfExpr (Parser.If (cond_expr, then_expr, else_expr)) ->
      let (cond_ty, cond_s) = check_expr ctx cond_expr in
      let (then_ty, then_s) = check_expr ctx then_expr in
      let (else_ty, else_s) = check_expr ctx else_expr in
      let s1 = [cond_s; then_s; else_s] |> compose_list in
      (* There is no boolean type yet, so the condition must be an int. *)
      let s2 = unify (subst s1 cond_ty) int_ty in
      let s3 = compose s2 s1 in
      (* Both branches must have the same type. *)
      let s4 = unify (subst s3 then_ty) (subst s3 else_ty) in
      let s5 = compose s4 s3 in
      (* The whole expression has the common type of the branches. *)
      (subst s5 then_ty, s5)
  | Parser.BinOpExpr (_op, left_expr, right_expr) ->
      let (left_ty, ls) = check_expr ctx left_expr in
      let (right_ty, rs) = check_expr ctx right_expr in
      let s = compose ls rs in
      (* Only int operands are supported for now: constrain the left
         operand first, then the right one. *)
      let s2 = unify (subst s left_ty) int_ty in
      let s3 = compose s2 s in
      let s4 = unify (subst s3 right_ty) int_ty in
      let s5 = compose s4 s3 in
      (* Every binary operator currently produces an int. *)
      (int_ty, s5)
  | Parser.MonoOpExpr (_op, expr) ->
      let (ty, s1) = check_expr ctx expr in
      let s2 = unify (subst s1 ty) int_ty in
      let s3 = compose s2 s1 in
      (* Every unary operator currently produces an int. *)
      (int_ty, s3)
  | Parser.CallExpr (Parser.Call (func_expr, arg_expr)) ->
      let (func_ty, func_s) = check_expr ctx func_expr in
      (* TODO: apply the substitution obtained so far to the context before
         checking the argument. *)
      let (arg_ty, arg_s) = check_expr ctx arg_expr in
      let s1 = compose func_s arg_s in
      (* Fresh variable standing for the type returned by the call. *)
      let ret_ty_var = new_type_variable () in
      let expected_func_ty = TypeArrow (subst s1 arg_ty, ret_ty_var) in
      (* The actual function type must match [argument -> result]. *)
      let s2 = unify (subst s1 func_ty) expected_func_ty in
      let s3 = compose s2 s1 in
      (* Resolved return type together with the final substitution. *)
      (subst s3 ret_ty_var, s3)
  | Parser.Identifier(name) -> (lookup ctx.env name, [])
  | Parser.Number(_n) -> (int_ty, [])
|
|
|
|
(* [check_let_expr ctx l] infers the type of a let-expression.

   The bound value is checked first. When the binding carries a type
   annotation, the inferred type is unified with it and the annotation
   (with the resulting substitution applied) becomes the type of the name.
   The body is then checked in a child scope that binds the name, and the
   type of the body is the type of the whole expression.

   Substitutions are combined as [compose newer older], and the final
   substitution is applied to the returned type so that it never mentions
   variables that have already been resolved.

   The child scope starts empty: [lookup] already walks the parent chain,
   so the parent's bindings do not need to be copied into it.

   TODO: the substitution is not applied to the types stored in the
   enclosing scopes.

   Raises [Failure] when unification fails or an identifier is unbound. *)
and check_let_expr (ctx: context_t) (l: Parser.let_expr_tree) : type_t * substitution_t =
  let (value_ty, s_value) = check_expr ctx l.value_expr in
  let (bound_ty, s_bound) =
    match l.type_declare with
    | Some(td) ->
        let declared_ty = type_tree2type_t td in
        (* Unify after applying what the value check has established. *)
        let s_unify = unify (subst s_value value_ty) declared_ty in
        let s_all = compose s_unify s_value in
        (subst s_all declared_ty, s_all)
    | None -> (subst s_value value_ty, s_value)
  in
  let new_env = {
    parent = Some ctx.env;
    bindings = Hashtbl.create 10;
  } in
  Hashtbl.replace new_env.bindings l.name bound_ty;
  let new_ctx = { ctx with env = new_env } in
  let (in_type, s_in) = check_expr new_ctx l.in_expr in
  (* The body was checked last, so its substitution goes on top. *)
  let s_final = compose s_in s_bound in
  (subst s_final in_type, s_final)
|
|
|
|
(* [check_fun_expr ctx ftree] infers the type of a single-argument function.

   The argument receives its declared type when there is an annotation and
   a fresh type variable otherwise. The body is checked in a child scope
   that binds the argument name. Checking the body can constrain the
   argument (for example when it is used where an int is required), so the
   substitution coming from the body is applied to both the argument type
   and the result type before the arrow is built. Applying it to only one
   side could return an arrow whose two halves disagree about a variable.

   Returns the function type and the substitution produced by the body.
   Raises [Failure] when unification fails or an identifier is unbound. *)
and check_fun_expr (ctx: context_t) (ftree: Parser.fun_expr_tree) : type_t * substitution_t =
  (* Fresh, empty scope for the body; outer names resolve through parent. *)
  let new_env = {
    parent = Some ctx.env;
    bindings = Hashtbl.create 10;
  } in
  let arg_type =
    match ftree.type_declare with
    | Some(td) -> type_tree2type_t td
    | None -> new_type_variable () (* no annotation: infer it *)
  in
  Hashtbl.add new_env.bindings ftree.name arg_type;
  let new_ctx = { ctx with env = new_env } in
  let (body_type, s_body) = check_expr new_ctx ftree.body_expr in
  let final_arg_type = subst s_body arg_type in
  let final_body_type = subst s_body body_type in
  (TypeArrow (final_arg_type, final_body_type), s_body)
|
|
|
|
(** [infer_type exp] infers the type of the closed expression [exp].

    Checking starts from an empty outermost scope with no constraints. The
    substitution returned by [check_expr] is applied to the inferred type
    one last time before it is returned. Exceptions raised while checking
    propagate to the caller. *)
let infer_type (exp: Parser.expr_tree) : type_t =
  let root_env = { parent = None; bindings = Hashtbl.create 10 } in
  let root_ctx = { env = root_env; constraints = [] } in
  let (ty, sigma) = check_expr root_ctx exp in
  subst sigma ty
|
|
|
|
(* Adding two integer literals, built directly as a syntax tree, is an int. *)
let%test "test_infer_type1" =
  let one = Parser.Number(1) in
  let two = Parser.Number(2) in
  let sum = Parser.BinOpExpr (Parser.Add, one, two) in
  infer_type sum = TypeConstructor ("int", [])
|
|
|
|
|
|
(* A function that adds one to its argument has type int -> int. The test
   fails, after printing a message, when the source cannot be parsed. *)
let%test "infer_type_fun_add" =
  let input_str = "fun x -> x + 1" in
  let int_ty = TypeConstructor ("int", []) in
  let parsed = input_str |> Lexer.lex_tokens_seq |> Parser.get_expr_tree_from_tokens in
  match parsed with
  | None ->
      Printf.printf "Failed to parse expression: %s\n" input_str;
      false
  | Some ast ->
      (* For debugging, print both sides with [type_t2str]. *)
      infer_type ast = TypeArrow (int_ty, int_ty)
|
|
|
|
(* Binding a function with let and applying it to a literal yields an int.
   The test fails, after printing a message, when the source cannot be
   parsed. *)
let%test "infer_type_let_fun" =
  let input_str = "let f = fun x -> x + 1 in f 10" in
  let parsed = input_str |> Lexer.lex_tokens_seq |> Parser.get_expr_tree_from_tokens in
  match parsed with
  | None ->
      Printf.printf "Failed to parse expression: %s\n" input_str;
      false
  | Some ast ->
      (* For debugging, print both sides with [type_t2str]. *)
      infer_type ast = TypeConstructor ("int", [])
|
|
|
|
(* A conditional whose branches are both integer literals is an int. The
   test fails, after printing a message, when the source cannot be parsed. *)
let%test "infer_type_if" =
  let input_str = "if 1 then 2 else 3" in
  let parsed = input_str |> Lexer.lex_tokens_seq |> Parser.get_expr_tree_from_tokens in
  match parsed with
  | None ->
      Printf.printf "Failed to parse expression: %s\n" input_str;
      false
  | Some ast ->
      (* For debugging, print both sides with [type_t2str]. *)
      infer_type ast = TypeConstructor ("int", [])