feat: add type checking and evaluation with type inference for expressions
This commit is contained in:
parent
2d3d8ccbd1
commit
d801a1b34f
3 changed files with 290 additions and 201 deletions
|
@ -4,10 +4,9 @@ open Calc;;
|
|||
let main () =
|
||||
let input = Sys.argv.(1) in
|
||||
let _ = Printf.printf "input: %s\n" input in
|
||||
let result = Eval.eval_str input in
|
||||
match result with
|
||||
| Eval.Int n -> Printf.printf "%d\n" n
|
||||
| _ -> failwith "Type error"
|
||||
let (result, ty) = Eval.eval_str_with_typecheck input in
|
||||
Printf.printf "type: %s\n" (Typecheck.type_t2str ty);
|
||||
Printf.printf "result: %s\n" (result |> Eval.string_of_value_type);
|
||||
;;
|
||||
|
||||
main ();;
|
16
lib/eval.ml
16
lib/eval.ml
|
@ -13,6 +13,11 @@ and function_type = {
|
|||
scope: scope;
|
||||
}
|
||||
|
||||
let string_of_value_type (v: value_type): string =
|
||||
match v with
|
||||
| Int n -> string_of_int n
|
||||
| Fun f -> Printf.sprintf "<function (%s) -> %s >" f.argname (f.body |> Parser.expr2str)
|
||||
|
||||
let rec eval_expr (scope: scope) (expr: Parser.expr_tree): value_type =
|
||||
match expr with
|
||||
| Parser.LetExpr (l) ->
|
||||
|
@ -104,6 +109,17 @@ let eval_str (str: string): value_type =
|
|||
| Some e -> eval_expr { parent = None; bindings = VariableBindingMap.empty } e
|
||||
| None -> failwith "Parse error"
|
||||
|
||||
let eval_str_with_typecheck (str: string): value_type * Typecheck.type_t =
|
||||
let tokens = Lexer.lex_tokens_seq str in
|
||||
let expr = Parser.get_expr_tree_from_tokens tokens in
|
||||
match expr with
|
||||
| Some e ->
|
||||
let ty = Typecheck.infer_type e in
|
||||
let v = eval_expr { parent = None; bindings = VariableBindingMap.empty } e in
|
||||
(v, ty)
|
||||
| None -> failwith "Parse error"
|
||||
|
||||
(* Test cases *)
|
||||
|
||||
let%test "test eval_str 1" =
|
||||
let result = eval_str "let x = 1 in x" in
|
||||
|
|
486
lib/typecheck.ml
486
lib/typecheck.ml
|
@ -1,215 +1,289 @@
|
|||
type type_v =
|
||||
| Int
|
||||
| Fun of {
|
||||
arg: type_v;
|
||||
ret: type_v;
|
||||
}
|
||||
| Generic of string
|
||||
| Universal
|
||||
| Nothing
|
||||
|
||||
type type_t =
|
||||
(* e.g. list, option, etc. *)
|
||||
| TypeConstructor of string * type_t list
|
||||
(* e.g. int, string, etc. *)
|
||||
| TypeVariable of string
|
||||
(* e.g. int -> int, string -> string, etc.
|
||||
This is a function type, where the first type is the input and the second type is the output.
|
||||
The arrow can be right associative, so (a -> b -> c) is equivalent to (a -> (b -> c)).
|
||||
*)
|
||||
| TypeArrow of type_t * type_t
|
||||
|
||||
type type_scope = {
|
||||
parent: type_scope option;
|
||||
bindings: (string, type_v) Hashtbl.t;
|
||||
generics_count: int ref;
|
||||
}
|
||||
|
||||
let make_type_scope (parent: type_scope): type_scope = {
|
||||
parent = Some (parent);
|
||||
bindings = Hashtbl.create 10;
|
||||
generics_count = parent.generics_count;
|
||||
}
|
||||
|
||||
let make_top_type_scope (): type_scope = {
|
||||
parent = None;
|
||||
bindings = Hashtbl.create 10;
|
||||
generics_count = ref 0;
|
||||
}
|
||||
|
||||
let rec typetree2type_v (t: Parser.type_tree): type_v =
|
||||
match t with
|
||||
| Parser.TypeIdentifier (x) -> if x = "int" then Int else
|
||||
failwith "not implemented (type alias is not supported yet)"
|
||||
| Parser.TypeArrow (arg, ret) -> Fun { arg = typetree2type_v arg; ret = typetree2type_v ret }
|
||||
|
||||
let rec type_v2str (t: type_v): string =
|
||||
match t with
|
||||
| Int -> "int"
|
||||
| Fun { arg = arg; ret = ret } -> Printf.sprintf "(%s -> %s)" (type_v2str arg) (type_v2str ret)
|
||||
| Generic s -> "'" ^ s
|
||||
| Universal -> "universal"
|
||||
| Nothing -> "nothing"
|
||||
|
||||
(* meet *)
|
||||
let rec intersect_type_v (a: type_v) (b: type_v): type_v =
|
||||
match a, b with
|
||||
| Universal, _ -> b
|
||||
| _, Universal -> a
|
||||
| Int, Int -> Int
|
||||
| Fun { arg = arg1; ret = ret1 }, Fun { arg = arg2; ret = ret2 } ->
|
||||
(* contravariance *)
|
||||
let arg = intersect_type_v arg1 arg2 in
|
||||
let ret = intersect_type_v ret1 ret2 in
|
||||
Fun { arg = arg; ret = ret }
|
||||
(* // TODO: fix it *)
|
||||
| Generic s1, Generic s2 when s1 = s2 -> Generic s1
|
||||
| Generic _, _ -> b
|
||||
| _ -> Nothing
|
||||
(* join *)
|
||||
(* and union_type_v (a: type_v) (b: type_v): type_v =
|
||||
match a, b with
|
||||
| Universal, _ -> Universal
|
||||
| _, Universal -> Universal
|
||||
| Int, Int -> Int
|
||||
| Fun { arg = arg1; ret = ret1 }, Fun { arg = arg2; ret = ret2 } ->
|
||||
|
||||
| Generic s1, Generic s2 when s1 = s2 -> Generic s1
|
||||
| Generic _, _ -> b
|
||||
| _ -> Nothing *)
|
||||
|
||||
(* it assumes that there is already variable binding. *)
|
||||
|
||||
let find_type_v_opt (scope: type_scope) (name: string): type_v option =
|
||||
let rec find_binding scope =
|
||||
match scope with
|
||||
| None -> None
|
||||
| Some s ->
|
||||
match Hashtbl.find_opt s.bindings name with
|
||||
| Some v -> Some v
|
||||
| None -> find_binding s.parent in
|
||||
find_binding (Some scope)
|
||||
|
||||
(* it assumes that there is already variable binding. *)
|
||||
let assert_and_get_type_v (scope: type_scope) (name: string) (expected: type_v) =
|
||||
let rec assert_binding scope =
|
||||
match scope with
|
||||
| None -> failwith "Unbound variable"
|
||||
| Some s ->
|
||||
match Hashtbl.find_opt s.bindings name with
|
||||
| Some v ->
|
||||
let subtype = intersect_type_v v expected in
|
||||
if subtype = Nothing then failwith "Type error"
|
||||
else Hashtbl.replace s.bindings name subtype;
|
||||
subtype
|
||||
| None -> assert_binding s.parent in
|
||||
assert_binding (Some scope)
|
||||
|
||||
let gen_generic_free_name (scope: type_scope): string =
|
||||
let generics_count = !(scope.generics_count) in
|
||||
let name = Printf.sprintf "%d" generics_count in
|
||||
scope.generics_count := generics_count + 1;
|
||||
name
|
||||
|
||||
let replace_generic_with (t: type_v) (from: string) (to_: type_v): type_v =
|
||||
let rec replace t =
|
||||
match t with
|
||||
| Int -> Int
|
||||
| Fun { arg = arg; ret = ret } -> Fun { arg = replace arg; ret = replace ret }
|
||||
| Generic s when s = from -> to_
|
||||
| Generic s -> Generic s
|
||||
| Universal -> Universal
|
||||
| Nothing -> Nothing in
|
||||
replace t
|
||||
|
||||
let rec typecheck_expr (scope: type_scope) (expr: Parser.expr_tree) (required_type: type_v): type_v =
|
||||
let actual_type = match expr with
|
||||
| Parser.LetExpr (l) ->
|
||||
typecheck_let_expr scope required_type l
|
||||
| Parser.FunExpr (ftree) ->
|
||||
typecheck_fun_expr scope required_type ftree
|
||||
| Parser.IfExpr (Parser.If (cond_expr, then_expr, else_expr)) ->
|
||||
typecheck_if_expr scope required_type cond_expr then_expr else_expr
|
||||
| Parser.BinOpExpr (op, left_expr, right_expr) ->
|
||||
typecheck_bin_op_expr scope required_type op left_expr right_expr
|
||||
| Parser.MonoOpExpr (_op, _expr) ->
|
||||
failwith "Not implemented"
|
||||
| Parser.CallExpr (Parser.Call (func_expr, arg_expr)) ->
|
||||
typecheck_call_expr scope required_type func_expr arg_expr
|
||||
| Parser.Identifier(name) -> assert_and_get_type_v scope name required_type
|
||||
| Parser.Number(_n) -> Int
|
||||
let type_t2str (ty: type_t) =
|
||||
let rec aux ty =
|
||||
match ty with
|
||||
| TypeConstructor (name, args) ->
|
||||
name ^ if List.is_empty args then "" else
|
||||
" of (" ^ String.concat ", " (List.map aux args) ^ ")"
|
||||
| TypeVariable x -> x
|
||||
| TypeArrow (a, b) -> "(" ^ aux a ^ " -> " ^ aux b ^ ")"
|
||||
in
|
||||
let subtype = intersect_type_v required_type actual_type in
|
||||
if subtype = Nothing then
|
||||
failwith (Printf.sprintf "Type error: expect %s but actual %s"
|
||||
(type_v2str required_type) (type_v2str actual_type)
|
||||
)
|
||||
else subtype
|
||||
and typecheck_let_expr (scope: type_scope) (required_type: type_v) ({
|
||||
name = name;
|
||||
value_expr = value_expr;
|
||||
in_expr = in_expr;
|
||||
type_declare = type_decl;
|
||||
}: Parser.let_expr_tree): type_v =
|
||||
let value_reqired_type = type_decl |> Option.map typetree2type_v |> Option.value ~default: Universal in
|
||||
let value_type = typecheck_expr scope value_expr value_reqired_type in
|
||||
let new_scope = make_type_scope scope in
|
||||
Hashtbl.add new_scope.bindings name value_type;
|
||||
typecheck_expr new_scope in_expr required_type
|
||||
and typecheck_fun_expr (scope: type_scope) (_required_type: type_v) ({
|
||||
name = argname;
|
||||
body_expr = body_expr;
|
||||
type_declare = type_decl;
|
||||
}: Parser.fun_expr_tree): type_v =
|
||||
let default_type = Generic (gen_generic_free_name scope) in
|
||||
let arg_type = type_decl |> Option.map typetree2type_v |> Option.value ~default: default_type in
|
||||
let new_scope = make_type_scope scope in
|
||||
Hashtbl.add new_scope.bindings argname arg_type;
|
||||
(* unreachable because *)
|
||||
let ret_type = typecheck_expr new_scope body_expr Universal in
|
||||
let arg_type = Hashtbl.find new_scope.bindings argname in
|
||||
Printf.printf "arg: %s, ret: %s\n" (type_v2str arg_type) (type_v2str ret_type);
|
||||
aux ty
|
||||
|
||||
Fun { arg = arg_type; ret = ret_type }
|
||||
and typecheck_if_expr (scope: type_scope) (required_type: type_v)
|
||||
(cond_expr: Parser.expr_tree) (then_expr: Parser.expr_tree) (else_expr: Parser.expr_tree): type_v =
|
||||
let _ = typecheck_expr scope cond_expr Int in
|
||||
let then_type = typecheck_expr scope then_expr required_type in
|
||||
let else_type = typecheck_expr scope else_expr required_type in
|
||||
intersect_type_v then_type else_type
|
||||
and typecheck_bin_op_expr (scope: type_scope) (_required_type: type_v)
|
||||
(_op: Parser.bin_op_type) (left_expr: Parser.expr_tree) (right_expr: Parser.expr_tree): type_v =
|
||||
(* default int *)
|
||||
let _ = typecheck_expr scope left_expr Int in
|
||||
let _ = typecheck_expr scope right_expr Int in
|
||||
Int
|
||||
and typecheck_call_expr (scope: type_scope) (_required_type: type_v)
|
||||
(func_expr: Parser.expr_tree) (arg_expr: Parser.expr_tree): type_v =
|
||||
let func_type = typecheck_expr scope func_expr Universal in
|
||||
Printf.printf "func_type: %s\n" (type_v2str func_type);
|
||||
type substitution_t = (string * type_t) list
|
||||
|
||||
match func_type with
|
||||
| Fun { arg = arg_type; ret = ret_type } ->
|
||||
let mono_arg_type = typecheck_expr scope arg_expr arg_type in
|
||||
Printf.printf "arg_type: %s\n" (type_v2str mono_arg_type);
|
||||
begin match arg_type with
|
||||
| Generic s ->
|
||||
(* instance *)
|
||||
let new_ret_type = replace_generic_with ret_type s mono_arg_type in
|
||||
Printf.printf "new_ret_type: %s\n" (type_v2str new_ret_type);
|
||||
new_ret_type
|
||||
| _ -> ret_type
|
||||
end
|
||||
| _ -> failwith "Type error"
|
||||
(* check if a variable occurs in a type_t *)
|
||||
let rec occurs (t : type_t) (id: string) : bool =
|
||||
match t with
|
||||
| TypeVariable x -> x = id
|
||||
| TypeConstructor (_, lst) -> List.exists (fun x -> (occurs x id)) lst
|
||||
| TypeArrow (a,b) -> (occurs a id) || (occurs b id)
|
||||
|
||||
let typecheck (expr: Parser.expr_tree): type_v =
|
||||
typecheck_expr (make_top_type_scope()) expr Universal
|
||||
let rec subst (s: substitution_t) (t: type_t) : type_t =
|
||||
match t with
|
||||
| TypeVariable x -> (try List.assoc x s with Not_found -> t)
|
||||
| TypeConstructor (name, args) -> TypeConstructor (name, List.map (subst s) args)
|
||||
| TypeArrow (a, b) -> TypeArrow (subst s a, subst s b)
|
||||
|
||||
let typecheck_result (expr: Parser.expr_tree): (type_v, exn) result =
|
||||
try
|
||||
let t = typecheck expr in
|
||||
Result.Ok (t)
|
||||
with e -> Result.Error e
|
||||
let compose (s1: substitution_t) (s2: substitution_t) : substitution_t =
|
||||
(* apply s1 to s2 *)
|
||||
let s2' = List.map (fun (x, t) -> (x, subst s1 t)) s2 in
|
||||
(* remove duplicates *)
|
||||
let s1' = List.filter (fun (x, _) -> not (List.mem_assoc x s2')) s1 in
|
||||
s1' @ s2'
|
||||
|
||||
let test_typecheck (content:string) =
|
||||
let tokens = Lexer.lex_tokens_seq content in
|
||||
let expr = Parser.get_expr_tree_from_tokens tokens in
|
||||
match expr with
|
||||
| Some e -> typecheck_result e
|
||||
| None -> Result.Error (Failure "parse error")
|
||||
let compose_list (slist : substitution_t list) =
|
||||
List.fold_left compose [] slist
|
||||
|
||||
let%test "typecheck 1" =
|
||||
let expr = "let x = fun y -> y in x 1" in
|
||||
match test_typecheck expr with
|
||||
| Result.Ok (t) -> Printf.printf "%s\n" (type_v2str t); t = Int
|
||||
| Result.Error _ -> Printf.printf "error\n"; false
|
||||
|
||||
let rec unify (t1: type_t) (t2: type_t) : substitution_t =
|
||||
match t1, t2 with
|
||||
| TypeVariable x, _ when not (occurs t2 x) -> [(x, t2)]
|
||||
| _, TypeVariable x when not (occurs t1 x) -> [(x, t1)]
|
||||
| TypeVariable x, TypeVariable y when x = y -> []
|
||||
| TypeConstructor (n1, args1), TypeConstructor (n2, args2) when n1 = n2 ->
|
||||
List.fold_left2 (fun s a1 a2 -> compose s (unify a1 a2)) [] args1 args2
|
||||
| TypeArrow (a1, b1), TypeArrow (a2, b2) ->
|
||||
let s1 = unify a1 a2 in
|
||||
let s2 = unify (subst s1 b1) (subst s1 b2) in
|
||||
compose s1 s2
|
||||
| _ -> raise (Failure ("Type mismatch: " ^ type_t2str t1 ^ " vs " ^ type_t2str t2))
|
||||
|
||||
let assert_type (t1: type_t) (t2: type_t) : unit =
|
||||
ignore @@ unify t1 t2
|
||||
|
||||
(* Test cases for the unify function *)
|
||||
let%test "unify1" =
|
||||
let t1 = TypeConstructor ("int", []) in
|
||||
let t2 = TypeVariable "a" in
|
||||
let s = unify t1 t2 in
|
||||
(* Printf.printf "%s\n" (type_t2str (subst s t1)); *)
|
||||
subst s t1 = subst s t2
|
||||
|
||||
let%test "unify2" =
|
||||
let t1 = TypeArrow (TypeVariable "a", TypeVariable "b") in
|
||||
let t2 = TypeArrow (TypeVariable "c", TypeVariable "d") in
|
||||
let s = unify t1 t2 in
|
||||
(* Printf.printf "%s\n" (s |> List.map (fun (x, t) -> x ^ " => " ^ type_t2str t) |> String.concat ", ");
|
||||
Printf.printf "%s\n" (type_t2str (subst s t1)); *)
|
||||
subst s t1 = subst s t2
|
||||
|
||||
let%test "unify3" =
|
||||
let t1 = TypeArrow (TypeVariable "a", TypeVariable "b") in
|
||||
let t2 = TypeArrow (TypeVariable "a", TypeVariable "c") in
|
||||
let s = unify t1 t2 in
|
||||
(* Printf.printf "%s\n" (s |> List.map (fun (x, t) -> x ^ " => " ^ type_t2str t) |> String.concat ", ");
|
||||
Printf.printf "%s\n" (type_t2str (subst s t1)); *)
|
||||
subst s t1 = subst s t2
|
||||
|
||||
let%test "unify4" =
|
||||
let t1 = TypeConstructor ("list", [TypeVariable "a"]) in
|
||||
let t2 = TypeConstructor ("list", [TypeVariable "b"]) in
|
||||
let s = unify t1 t2 in
|
||||
(* Printf.printf "%s\n" (s |> List.map (fun (x, t) -> x ^ " => " ^ type_t2str t) |> String.concat ", ");
|
||||
Printf.printf "%s\n" (type_t2str (subst s t1)); *)
|
||||
subst s t1 = subst s t2
|
||||
|
||||
type env_t = {
|
||||
parent : env_t option;
|
||||
bindings : (string, type_t) Hashtbl.t;
|
||||
}
|
||||
|
||||
let rec lookup (env: env_t) (name: string) =
|
||||
match Hashtbl.find_opt env.bindings name with
|
||||
| Some ty -> ty
|
||||
| None -> match env.parent with
|
||||
| Some parent -> lookup parent name
|
||||
| None -> raise (Failure ("Unbound variable: " ^ name))
|
||||
|
||||
type context_t = {
|
||||
env : env_t;
|
||||
constraints : (string * type_t) list;
|
||||
}
|
||||
|
||||
let rec type_tree2type_t (tt: Parser.type_tree) =
|
||||
match tt with
|
||||
| TypeIdentifier(s) -> TypeConstructor(s, [])
|
||||
| TypeArrow(arg, ret) -> TypeArrow(type_tree2type_t(arg), type_tree2type_t(ret))
|
||||
|
||||
let type_variable_counter = ref 0
|
||||
let new_type_variable () =
|
||||
let i = !type_variable_counter in
|
||||
incr type_variable_counter;
|
||||
TypeVariable ("'t" ^ string_of_int i)
|
||||
|
||||
let rec check_expr (ctx: context_t) (exp: Parser.expr_tree) : (type_t * substitution_t) =
|
||||
match exp with
|
||||
| Parser.LetExpr (l) ->
|
||||
let ty = check_let_expr ctx l in
|
||||
ty
|
||||
| Parser.FunExpr (ftree) ->
|
||||
let ty = check_fun_expr ctx ftree in
|
||||
ty
|
||||
| Parser.IfExpr (Parser.If (cond_expr, then_expr, else_expr)) ->
|
||||
let (cond_ty, cond_s) = check_expr ctx cond_expr in
|
||||
let (then_ty, then_s) = check_expr ctx then_expr in
|
||||
let (else_ty, else_s) = check_expr ctx else_expr in
|
||||
let s1 = [cond_s; then_s; else_s] |> compose_list in
|
||||
(* currently, boolean type doesn't exist. *)
|
||||
let s2 = unify cond_ty (TypeConstructor ("int", [])) in
|
||||
let s3 = compose s1 s2 in
|
||||
(* unify the types of then and else branches *)
|
||||
let s4 = (unify (subst s3 then_ty) (subst s3 else_ty)) in
|
||||
let s5 = compose s4 s3 in
|
||||
(* return the type of then and else branches *)
|
||||
(subst s5 then_ty, s5)
|
||||
| Parser.BinOpExpr (_op, left_expr, right_expr) ->
|
||||
let (left_ty, ls) = check_expr ctx left_expr in
|
||||
let (right_ty, rs) = check_expr ctx right_expr in
|
||||
(* unify the types of the left and right expressions *)
|
||||
(* but currently, only int is supported. *)
|
||||
let s = compose ls rs in
|
||||
let s2 = unify (subst s left_ty) (TypeConstructor ("int", [])) in
|
||||
let s3 = compose s s2 in
|
||||
(* unify the types of the left and right expressions *)
|
||||
let s4 = unify (subst s3 right_ty) (TypeConstructor ("int", [])) in
|
||||
let s5 = compose s3 s4 in
|
||||
(* return int type *)
|
||||
(TypeConstructor ("int", []), s5)
|
||||
| Parser.MonoOpExpr (_op, expr) ->
|
||||
let (ty, s1) = check_expr ctx expr in
|
||||
let s2 = unify (subst s1 ty) (TypeConstructor ("int", [])) in
|
||||
let s3 = compose s1 s2 in
|
||||
(* return int type *)
|
||||
(TypeConstructor ("int", []), s3)
|
||||
| Parser.CallExpr (Parser.Call (func_expr, arg_expr)) ->
|
||||
let (func_ty, func_s) = check_expr ctx func_expr in
|
||||
(* s1을 context에 적용하는 로직 추가 필요 *)
|
||||
let (arg_ty, arg_s) = check_expr ctx arg_expr in
|
||||
let s1 = compose func_s arg_s in
|
||||
let ret_ty_var = new_type_variable () in (* 반환 타입을 위한 새 타입 변수 *)
|
||||
let expected_func_ty = TypeArrow (subst s1 arg_ty, ret_ty_var) in
|
||||
let s2 = unify (subst s1 func_ty) expected_func_ty in (* 실제 함수 타입과 기대 타입 통일 *)
|
||||
let s3 = compose s1 s2 in
|
||||
(subst s3 ret_ty_var, s3) (* 추론된 반환 타입과 최종 substitution 반환 *)
|
||||
| Parser.Identifier(name) ->
|
||||
let ty = lookup ctx.env name in
|
||||
(ty, [])
|
||||
| Parser.Number(_n) -> (TypeConstructor ("int", []), [])
|
||||
|
||||
and check_let_expr (ctx: context_t) (l: Parser.let_expr_tree) : type_t * substitution_t =
|
||||
let (vtt, s1) = check_expr ctx l.value_expr in
|
||||
let (vt, s2) = match l.type_declare with
|
||||
| Some(td) ->
|
||||
let tt = type_tree2type_t td in
|
||||
let s_unify = unify (subst s1 vtt) tt in (* s1 적용 후 통일 *)
|
||||
let s_compose = compose s1 s_unify in
|
||||
(subst s_compose tt, s_compose)
|
||||
| None -> (subst s1 vtt, s1) (* s1 적용 *)
|
||||
in
|
||||
(* s2는 값 검사와 타입 선언 통일로부터 얻은 최종 substitution *)
|
||||
let final_vt = vt in
|
||||
|
||||
(* s2를 context에 적용하는 로직 추가 필요 *)
|
||||
let new_env = {
|
||||
parent = Some ctx.env; (* 수정된 context의 env 사용 필요 *)
|
||||
bindings = Hashtbl.copy ctx.env.bindings; (* 수정된 context의 bindings 사용 필요 *)
|
||||
} in
|
||||
Hashtbl.replace new_env.bindings l.name final_vt;
|
||||
let new_ctx = { ctx with env = new_env } in (* 수정된 context 사용 필요 *)
|
||||
|
||||
let (in_type, s3) = check_expr new_ctx l.in_expr in
|
||||
let s4 = compose s2 s3 in
|
||||
(in_type, s4)
|
||||
|
||||
and check_fun_expr (ctx: context_t) (ftree: Parser.fun_expr_tree) : type_t * substitution_t =
|
||||
(* Create a new environment for the function body *)
|
||||
let new_env = {
|
||||
parent = Some ctx.env;
|
||||
bindings = Hashtbl.create 10; (* Start with an empty table for the new scope *)
|
||||
} in
|
||||
|
||||
(* Determine the argument type *)
|
||||
let arg_type = match ftree.type_declare with
|
||||
| Some(td) -> type_tree2type_t td
|
||||
| None -> new_type_variable () (* Infer the type if not declared *)
|
||||
in
|
||||
|
||||
(* Add the argument binding to the new environment *)
|
||||
Hashtbl.add new_env.bindings ftree.name arg_type;
|
||||
let new_ctx = { ctx with env = new_env } in
|
||||
|
||||
(* Check the type of the function body in the new context *)
|
||||
(* 이 과정에서 substitution이 발생할 수 있음 (예: x가 int를 요구하는 곳에 사용될 때) *)
|
||||
let (body_type, s_body) = check_expr new_ctx ftree.body_expr in
|
||||
|
||||
(* 함수 본문 검사에서 얻은 substitution을 인자 타입에 적용 *)
|
||||
let final_arg_type = subst s_body arg_type in
|
||||
|
||||
(* 최종 함수 타입과 발생한 substitution 반환 *)
|
||||
(TypeArrow (final_arg_type, body_type), s_body)
|
||||
|
||||
let infer_type (exp: Parser.expr_tree) : type_t =
|
||||
let initial_env = { parent = None; bindings = Hashtbl.create 10 } in
|
||||
let initial_ctx = { env = initial_env; constraints = [] } in
|
||||
let (inferred_type, final_substitution) = check_expr initial_ctx exp in
|
||||
(* 최종적으로 얻은 substitution을 전체 타입에 적용 *)
|
||||
subst final_substitution inferred_type
|
||||
|
||||
let%test "test_infer_type1" =
|
||||
let exp = Parser.BinOpExpr (Parser.Add, Parser.Number(1), Parser.Number(2)) in
|
||||
let inferred_type = infer_type exp in
|
||||
inferred_type = TypeConstructor ("int", [])
|
||||
|
||||
|
||||
let%test "infer_type_fun_add" =
|
||||
let input_str = "fun x -> x + 1" in
|
||||
let tokens = Lexer.lex_tokens_seq input_str in
|
||||
let expr_opt = Parser.get_expr_tree_from_tokens tokens in
|
||||
match expr_opt with
|
||||
| Some expr ->
|
||||
let inferred = infer_type expr in
|
||||
let expected = TypeArrow (TypeConstructor ("int", []), TypeConstructor ("int", [])) in
|
||||
(* Printf.printf "Inferred: %s\n" (type_t2str inferred);
|
||||
Printf.printf "Expected: %s\n" (type_t2str expected); *)
|
||||
inferred = expected
|
||||
| None ->
|
||||
Printf.printf "Failed to parse expression: %s\n" input_str;
|
||||
false
|
||||
|
||||
let%test "infer_type_let_fun" =
|
||||
let input_str = "let f = fun x -> x + 1 in f 10" in
|
||||
let tokens = Lexer.lex_tokens_seq input_str in
|
||||
let expr_opt = Parser.get_expr_tree_from_tokens tokens in
|
||||
match expr_opt with
|
||||
| Some expr ->
|
||||
let inferred = infer_type expr in
|
||||
let expected = TypeConstructor ("int", []) in
|
||||
(* Printf.printf "Inferred: %s\n" (type_t2str inferred);
|
||||
Printf.printf "Expected: %s\n" (type_t2str expected); *)
|
||||
inferred = expected
|
||||
| None ->
|
||||
Printf.printf "Failed to parse expression: %s\n" input_str;
|
||||
false
|
||||
|
||||
let%test "infer_type_if" =
|
||||
let input_str = "if 1 then 2 else 3" in
|
||||
let tokens = Lexer.lex_tokens_seq input_str in
|
||||
let expr_opt = Parser.get_expr_tree_from_tokens tokens in
|
||||
match expr_opt with
|
||||
| Some expr ->
|
||||
let inferred = infer_type expr in
|
||||
let expected = TypeConstructor ("int", []) in
|
||||
(* Printf.printf "Inferred: %s\n" (type_t2str inferred);
|
||||
Printf.printf "Expected: %s\n" (type_t2str expected); *)
|
||||
inferred = expected
|
||||
| None ->
|
||||
Printf.printf "Failed to parse expression: %s\n" input_str;
|
||||
false
|
Loading…
Add table
Reference in a new issue