feat: implement type checking for expressions and add type scope management
This commit is contained in:
		
							parent
							
								
									b503bd29e8
								
							
						
					
					
						commit
						2d3d8ccbd1
					
				
					 3 changed files with 245 additions and 34 deletions
				
			
		|  | @ -63,7 +63,7 @@ and eval_fun_expr scope (ftree: Parser.fun_expr_tree) = | |||
|   Fun { argname = ftree.name; body = ftree.body_expr; scope = scope } | ||||
| and eval_bin_op_expr scope op left_expr right_expr =  | ||||
|   let left = eval_expr scope left_expr in | ||||
|     let right = eval_expr scope right_expr in | ||||
|   let right = eval_expr scope right_expr in | ||||
|     (match op with | ||||
|     | Add -> ( | ||||
|       match (left, right) with | ||||
|  |  | |||
|  | @ -371,50 +371,46 @@ let get_expr_tree_from_tokens (tokens: (Token.t * Lexer.lexer_context) Seq.t): e | |||
|   | Some (e, _) -> Some e | ||||
|   | None -> None | ||||
| 
 | ||||
(** [normalize_calc_string s] lexes [s], parses the token sequence into an
    expression tree and renders that tree back to text with [expr2str].
    Returns the empty string when no expression tree is produced. *)
let normalize_calc_string (s: string): string =
  match get_expr_tree_from_tokens (Lexer.lex_tokens_seq s) with
  | Some tree -> expr2str tree
  | None -> ""
| 
 | ||||
(* A [let] expression round-trips; the indentation before the body is dropped
   by the printer. *)
let%test "test get_expr_tree_from_tokens 1" =
  let actual = normalize_calc_string "let x = 1 in\n  x" in
  let expected = "let x = 1 in\nx" in
  actual = expected
| 
 | ||||
(* A function expression is printed with its body on a new line. *)
let%test "test get_expr_tree_from_tokens 2" =
  let actual = normalize_calc_string "fun x -> x" in
  let expected = "fun x ->\nx" in
  actual = expected
| 
 | ||||
(* An [if]/[then]/[else] expression round-trips unchanged. *)
let%test "test get_expr_tree_from_tokens 3" =
  let actual = normalize_calc_string "if 1 then 2 else 3" in
  let expected = "if 1 then 2 else 3" in
  actual = expected
| 
 | ||||
(* A mixed [+]/[*] arithmetic expression round-trips unchanged. *)
let%test "test get_expr_tree_from_tokens 4" =
  let actual = normalize_calc_string "1 + 2 * 3" in
  let expected = "1 + 2 * 3" in
  actual = expected
| 
 | ||||
(* Juxtaposition application is printed as nested, parenthesized calls. *)
let%test "test get_expr_tree_from_tokens 5" =
  let actual = normalize_calc_string "x 1 2" in
  let expected = "x(1)(2)" in
  actual = expected
| 
 | ||||
(* A type annotation on a [let] binding is kept by the printer. *)
let%test "test get_expr_tree_from_tokens 6 with type" =
  let actual = normalize_calc_string "let x: int = 1 in\n  x" in
  let expected = "let x: int = 1 in\nx" in
  actual = expected
| 
 | ||||
(* A type annotation on a function parameter is kept by the printer. *)
let%test "test get_expr_tree_from_tokens 7 with type" =
  let actual = normalize_calc_string "fun (x: int) -> x" in
  let expected = "fun (x: int) ->\nx" in
  actual = expected
| 
 | ||||
(* Parentheses around an unannotated parameter are dropped by the printer. *)
let%test "test get_expr_tree_from_tokens 8" =
  let actual = normalize_calc_string "fun (x) -> x" in
  let expected = "fun x ->\nx" in
  actual = expected
| 
 | ||||
|  |  | |||
							
								
								
									
										215
									
								
								lib/typecheck.ml
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										215
									
								
								lib/typecheck.ml
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,215 @@ | |||
(** Types manipulated by the checker.
    - [Int]: the integer type.
    - [Fun]: a function taking [arg] and returning [ret].
    - [Generic]: a named type variable; printed with a leading quote.
    - [Universal]: intersecting with it yields the other operand.
    - [Nothing]: result of intersecting incompatible types; the checker
      reports it as a type error. *)
type type_v =
  | Int
  | Fun of { arg: type_v; ret: type_v }
  | Generic of string
  | Universal
  | Nothing
| 
 | ||||
| 
 | ||||
(** One level of the typing environment. *)
type type_scope = {
  parent: type_scope option;
  (** Enclosing scope; [None] for the top-level scope. *)
  bindings: (string, type_v) Hashtbl.t;
  (** Variable name to its currently known type, for this level only. *)
  generics_count: int ref;
  (** Counter used to name fresh generics; child scopes created by
      [make_type_scope] share the parent's cell. *)
}
| 
 | ||||
(** [make_type_scope parent] creates an empty scope nested inside [parent].
    The new scope shares [parent]'s generic-name counter, so names stay unique
    across the whole scope chain. *)
let make_type_scope (parent: type_scope): type_scope =
  let bindings = Hashtbl.create 10 in
  { parent = Some parent; bindings; generics_count = parent.generics_count }
| 
 | ||||
(** [make_top_type_scope ()] creates a root scope: no parent, no bindings and
    a fresh generic-name counter starting at 0. *)
let make_top_type_scope (): type_scope =
  let generics_count = ref 0 in
  let bindings = Hashtbl.create 10 in
  { parent = None; bindings; generics_count }
| 
 | ||||
(** [typetree2type_v t] converts a parsed type annotation into a [type_v].
    Only the identifier ["int"] and arrow types are supported.
    @raise Failure for any other type identifier. *)
let rec typetree2type_v (t: Parser.type_tree): type_v =
  match t with
  | Parser.TypeIdentifier "int" -> Int
  | Parser.TypeIdentifier _ ->
    failwith "not implemented (type alias is not supported yet)"
  | Parser.TypeArrow (arg_tree, ret_tree) ->
    Fun { arg = typetree2type_v arg_tree; ret = typetree2type_v ret_tree }
| 
 | ||||
(** [type_v2str t] renders [t] for messages: function types are printed as
    ["(arg -> ret)"] and generics with a leading quote. *)
let rec type_v2str (t: type_v): string =
  match t with
  | Int -> "int"
  | Universal -> "universal"
  | Nothing -> "nothing"
  | Generic name -> "'" ^ name
  | Fun { arg; ret } ->
    String.concat "" [ "("; type_v2str arg; " -> "; type_v2str ret; ")" ]
| 
 | ||||
(** [intersect_type_v a b] computes the meet of [a] and [b].
    - [Universal] on either side yields the other operand.
    - [Int] meets only [Int].
    - Two function types are intersected component-wise.
    - A [Generic] meets a [Generic] of the same name as itself; otherwise a
      [Generic] on either side gives way to the other operand.
    - Every remaining combination yields [Nothing].

    The operation is symmetric with respect to [Generic]: previously only a
    [Generic] on the left was handled, so [intersect_type_v Int (Generic s)]
    returned [Nothing] while the flipped call returned [Int]. *)
let rec intersect_type_v (a: type_v) (b: type_v): type_v =
  match a, b with
  | Universal, _ -> b
  | _, Universal -> a
  | Int, Int -> Int
  | Fun { arg = arg1; ret = ret1 }, Fun { arg = arg2; ret = ret2 } ->
    (* TODO: arguments are intersected exactly like results. A variance-aware
       meet would combine the arguments with the join instead, which does not
       exist yet. *)
    let arg = intersect_type_v arg1 arg2 in
    let ret = intersect_type_v ret1 ret2 in
    Fun { arg; ret }
  | Generic s1, Generic s2 when s1 = s2 -> Generic s1
  | Generic _, _ -> b
  (* Mirror of the arm above so the meet does not depend on operand order. *)
  | _, Generic _ -> a
  | _ -> Nothing
| (* join *) | ||||
| (* and union_type_v (a: type_v) (b: type_v): type_v =  | ||||
|   match a, b with | ||||
|   | Universal, _ -> Universal | ||||
|   | _, Universal -> Universal | ||||
|   | Int, Int -> Int | ||||
|   | Fun { arg = arg1; ret = ret1 }, Fun { arg = arg2; ret = ret2 } ->  | ||||
| 
 | ||||
|   | Generic s1, Generic s2 when s1 = s2 -> Generic s1 | ||||
|   | Generic _, _ -> b | ||||
|   | _ -> Nothing *) | ||||
| 
 | ||||
(** [find_type_v_opt scope name] looks [name] up in [scope] and then in each
    enclosing scope, innermost first. Returns [None] when no scope on the
    chain binds [name]. *)
let find_type_v_opt (scope: type_scope) (name: string): type_v option =
  let rec lookup (current: type_scope) =
    match Hashtbl.find_opt current.bindings name with
    | Some _ as found -> found
    | None ->
      (match current.parent with
       | Some outer -> lookup outer
       | None -> None)
  in
  lookup scope
| 
 | ||||
(** [assert_and_get_type_v scope name expected] finds the innermost binding of
    [name], intersects its type with [expected], stores the narrowed type back
    into the scope that holds the binding and returns it.
    @raise Failure ["Unbound variable"] when no scope binds [name].
    @raise Failure ["Type error"] when the intersection is [Nothing]. *)
let assert_and_get_type_v (scope: type_scope) (name: string) (expected: type_v) =
  let rec narrow_in (current: type_scope option) =
    match current with
    | None -> failwith "Unbound variable"
    | Some s ->
      (match Hashtbl.find_opt s.bindings name with
       | None -> narrow_in s.parent
       | Some bound ->
         let narrowed = intersect_type_v bound expected in
         if narrowed = Nothing then failwith "Type error";
         (* Record what was learned so later uses see the narrowed type. *)
         Hashtbl.replace s.bindings name narrowed;
         narrowed)
  in
  narrow_in (Some scope)
| 
 | ||||
(** [gen_generic_free_name scope] returns the current value of the scope's
    generic counter as a decimal string and then increments the counter. *)
let gen_generic_free_name (scope: type_scope): string =
  let counter = scope.generics_count in
  let fresh = string_of_int !counter in
  incr counter;
  fresh
| 
 | ||||
(** [replace_generic_with t from to_] returns [t] with every occurrence of
    [Generic from] replaced by [to_]. Other generics and all non-generic
    leaves are left unchanged; function types are rewritten recursively. *)
let replace_generic_with (t: type_v) (from: string) (to_: type_v): type_v =
  let rec subst = function
    | Generic name -> if name = from then to_ else Generic name
    | Fun { arg; ret } -> Fun { arg = subst arg; ret = subst ret }
    | (Int | Universal | Nothing) as leaf -> leaf
  in
  subst t
| 
 | ||||
(** [typecheck_expr scope expr required_type] infers the type of [expr] in
    [scope] by dispatching on its syntactic form, then intersects the result
    with [required_type] and returns that intersection.
    @raise Failure when the intersection is [Nothing], for unary operator
    expressions (not implemented), or when a sub-check fails. *)
let rec typecheck_expr (scope: type_scope) (expr: Parser.expr_tree) (required_type: type_v): type_v =
  let actual_type =
    match expr with
    | Parser.Number _ -> Int
    | Parser.Identifier name ->
      assert_and_get_type_v scope name required_type
    | Parser.LetExpr let_tree ->
      typecheck_let_expr scope required_type let_tree
    | Parser.FunExpr fun_tree ->
      typecheck_fun_expr scope required_type fun_tree
    | Parser.IfExpr (Parser.If (cond_expr, then_expr, else_expr)) ->
      typecheck_if_expr scope required_type cond_expr then_expr else_expr
    | Parser.BinOpExpr (op, left_expr, right_expr) ->
      typecheck_bin_op_expr scope required_type op left_expr right_expr
    | Parser.CallExpr (Parser.Call (func_expr, arg_expr)) ->
      typecheck_call_expr scope required_type func_expr arg_expr
    | Parser.MonoOpExpr (_, _) ->
      failwith "Not implemented"
  in
  match intersect_type_v required_type actual_type with
  | Nothing ->
    failwith
      (Printf.sprintf "Type error: expect %s but actual %s"
         (type_v2str required_type) (type_v2str actual_type))
  | subtype -> subtype
(* Checks [let name (: type)? = value_expr in in_expr].
   The bound value is checked against its annotation (or [Universal] when
   there is none); the body is then checked in a child scope where [name] has
   the value's type, against the type required of the whole expression. *)
and typecheck_let_expr (scope: type_scope) (required_type: type_v)
  ({ name; value_expr; in_expr; type_declare }: Parser.let_expr_tree): type_v =
  let declared =
    match type_declare with
    | Some annotation -> typetree2type_v annotation
    | None -> Universal
  in
  let value_type = typecheck_expr scope value_expr declared in
  let body_scope = make_type_scope scope in
  Hashtbl.add body_scope.bindings name value_type;
  typecheck_expr body_scope in_expr required_type
(* Checks [fun name (: type)? -> body_expr] and returns its [Fun] type.
   The parameter gets its annotated type, or a fresh generic when it has no
   annotation. The required type passed by the caller is not used here.
   Prints a debug line with the resulting argument and return types. *)
and typecheck_fun_expr (scope: type_scope) (_required_type: type_v)
  ({ name = param; body_expr; type_declare }: Parser.fun_expr_tree): type_v =
  (* The fresh name is drawn before looking at the annotation, so the counter
     advances even for annotated parameters. *)
  let fresh_generic = Generic (gen_generic_free_name scope) in
  let declared_arg =
    match type_declare with
    | Some annotation -> typetree2type_v annotation
    | None -> fresh_generic
  in
  let body_scope = make_type_scope scope in
  Hashtbl.add body_scope.bindings param declared_arg;
  let ret_type = typecheck_expr body_scope body_expr Universal in
  (* Re-read the parameter: checking the body may have narrowed its binding. *)
  let arg_type = Hashtbl.find body_scope.bindings param in
  Printf.printf "arg: %s, ret: %s\n" (type_v2str arg_type) (type_v2str ret_type);
  Fun { arg = arg_type; ret = ret_type }
(* Checks [if cond then a else b]: the condition must be [Int], both branches
   are checked against the required type, and the result is the intersection
   of the two branch types. *)
and typecheck_if_expr (scope: type_scope) (required_type: type_v)
  (cond_expr: Parser.expr_tree) (then_expr: Parser.expr_tree) (else_expr: Parser.expr_tree): type_v =
  let check_branch branch = typecheck_expr scope branch required_type in
  ignore (typecheck_expr scope cond_expr Int : type_v);
  let then_type = check_branch then_expr in
  let else_type = check_branch else_expr in
  intersect_type_v then_type else_type
(* Checks a binary operation: whatever the operator, both operands must be
   [Int] and the result is [Int]. The left operand is checked first. *)
and typecheck_bin_op_expr (scope: type_scope) (_required_type: type_v)
  (_op: Parser.bin_op_type) (left_expr: Parser.expr_tree) (right_expr: Parser.expr_tree): type_v =
  let require_int operand =
    ignore (typecheck_expr scope operand Int : type_v)
  in
  require_int left_expr;
  require_int right_expr;
  Int
(* Checks a call [func_expr arg_expr].
   The callee must have a [Fun] type; the argument is checked against the
   callee's parameter type. When that parameter type is a bare generic, the
   generic is replaced by the argument's type throughout the return type.
   Prints debug lines for the callee, argument and instantiated return types.
   Raises [Failure "Type error"] when the callee is not a function. *)
and typecheck_call_expr (scope: type_scope) (_required_type: type_v)
  (func_expr: Parser.expr_tree) (arg_expr: Parser.expr_tree): type_v =
  let func_type = typecheck_expr scope func_expr Universal in
  Printf.printf "func_type: %s\n" (type_v2str func_type);
  match func_type with
  | Int | Generic _ | Universal | Nothing -> failwith "Type error"
  | Fun { arg = param_type; ret = ret_type } ->
    let mono_arg_type = typecheck_expr scope arg_expr param_type in
    Printf.printf "arg_type: %s\n" (type_v2str mono_arg_type);
    (match param_type with
     | Int | Fun _ | Universal | Nothing -> ret_type
     | Generic var ->
       (* Instantiate the generic parameter with the concrete argument type. *)
       let instantiated = replace_generic_with ret_type var mono_arg_type in
       Printf.printf "new_ret_type: %s\n" (type_v2str instantiated);
       instantiated)
| 
 | ||||
(** [typecheck expr] type-checks [expr] in a fresh top-level scope with no
    constraint on the result ([Universal]) and returns the inferred type.
    @raise Failure when checking fails. *)
let typecheck (expr: Parser.expr_tree): type_v =
  let top_scope = make_top_type_scope () in
  typecheck_expr top_scope expr Universal
| 
 | ||||
(** [typecheck_result expr] runs [typecheck] and returns [Ok t] with the
    inferred type, or [Error e] carrying whatever exception was raised. *)
let typecheck_result (expr: Parser.expr_tree): (type_v, exn) result =
  match typecheck expr with
  | inferred -> Ok inferred
  | exception e -> Error e
| 
 | ||||
(** [test_typecheck content] lexes, parses and type-checks the source text
    [content]. Returns [Error (Failure "parse error")] when the parser yields
    no expression tree, otherwise the result of [typecheck_result]. *)
let test_typecheck (content:string) =
  content
  |> Lexer.lex_tokens_seq
  |> Parser.get_expr_tree_from_tokens
  |> Option.fold ~none:(Error (Failure "parse error")) ~some:typecheck_result
| 
 | ||||
(* Applying a let-bound identity function to a number must yield [Int].
   The inferred type (or "error") is also printed. *)
let%test "typecheck 1" =
  match test_typecheck "let x = fun y -> y in x 1" with
  | Error _ ->
    Printf.printf "error\n";
    false
  | Ok inferred ->
    Printf.printf "%s\n" (type_v2str inferred);
    inferred = Int
		Loading…
	
	Add table
		
		Reference in a new issue