-- | A small expression language: desugaring from the parser's surface
-- syntax, a type checker, an evaluator, and a REPL entry point.

import Data.List (intercalate)
import qualified Table
import EnvSimple hiding (T)
import qualified EnvSimple (T)
import Parser
import Repl

type Symbol = String

-- | Runtime values produced by 'eval'.
data Value
  = VNum Integer
    -- constructor name, number of arguments still expected, arguments so far
  | VCtor Symbol Int [Value]
  | VRec (Table.T Symbol Value)
    -- closure: parameter name, captured environment, body
  | VFun Symbol Environment Expr

-- | Static types.
data Type
  = TInt
    -- reference to a named type, resolved through the types table
  | TName Symbol
  | TFun Type Type
    -- variant type: its name and a table from constructor to argument types
  | TVar Symbol (Table.T Symbol [Type])
    -- record type: a table from field name to field type
  | TRec (Table.T Symbol Type)
  deriving Eq

-- | Core (desugared) expressions.
data Expr
  = EId Symbol
  | ENum Integer
  | EPlus Expr Expr
  | ETimes Expr Expr
  | EFun Symbol Type Expr
  | EApp Expr Expr
  | ELet Symbol Type Expr Expr
  | ERecord [(Symbol, Expr)]
  | EMember Expr Symbol
  | ETypeVar Symbol [(Symbol, [Type])] Expr
  | ETypeRec Symbol [(Symbol, Type)] Expr
  | ETypeDef Symbol Type Expr
  | ECase Expr [(Pattern, Expr)]
  deriving Eq

-- | Patterns usable in a 'ECase' alternative.
data Pattern
  = ElsePat
  | VarPat Symbol
  | CtorPat Symbol [Symbol]
  | NumPat Integer
  deriving Eq

type Environment = EnvSimple.T Value

type Declarations = Table.T Symbol Type

-- | Split a curried function type into its argument types and its final
-- result type, e.g. @a -> b -> c@ becomes @([a, b], c)@.
uncurry_type :: Type -> ([Type], Type)
uncurry_type (TFun a b) = (a : arg_types, ret_type)
  where
    (arg_types, ret_type) = uncurry_type b
uncurry_type a = ([], a)

-- | Follow 'TName' aliases through the types table until a type that is
-- not an alias is reached.
resolve_type :: Declarations -> Type -> Type
resolve_type types (TName n) = resolve_type types (Table.lookup types n)
resolve_type _ t = t

-- | Type equality modulo named-type aliases.
--
-- 'TName' references are looked up in @types@ before comparing; variant
-- types ('TVar') are compared by name only; records compare field-wise and
-- must have the same set of keys.
equal_types :: Declarations -> Type -> Type -> Bool
equal_types types t1 t2 =
  case (t1, t2) of
    (TName n, _) -> equal_types types (Table.lookup types n) t2
    (_, TName n) -> equal_types types t1 (Table.lookup types n)
    (TInt, TInt) -> True
    (TFun a1 b1, TFun a2 b2) ->
      equal_types types a1 a2 && equal_types types b1 b2
    (TVar n1 _, TVar n2 _) -> n1 == n2
    (TRec f1, TRec f2) ->
      Table.fold2
        (\_ s1 s2 acc -> acc && equal_types types s1 s2)
        (Table.same_keys f1 f2)
        f1
        f2
    _ -> False

-- | Both operands of an arithmetic operator must have type int (or an
-- alias of it); the result is int.
type_check_arith :: Declarations -> Declarations -> Expr -> Expr -> Type
type_check_arith types decl e1 e2
  | is_int e1 && is_int e2 = TInt
  | otherwise = error "type error"
  where
    is_int e = equal_types types (type_check types decl e) TInt

-- | Compute the type of an expression.
--
-- @types@ maps type names to their definitions; @decl@ maps variables and
-- constructors in scope to their types. Ill-typed programs abort with
-- 'error'.
type_check :: Declarations -> Declarations -> Expr -> Type
type_check _ _ (ENum _) = TInt
type_check _ decl (EId x) = Table.lookup decl x
type_check types decl (EFun x t body) =
  TFun t (type_check types (Table.bind decl x t) body)
type_check types decl (EPlus e1 e2) = type_check_arith types decl e1 e2
type_check types decl (ETimes e1 e2) = type_check_arith types decl e1 e2
type_check types decl (EApp f a) =
  -- resolve aliases so a variable declared with a named function type can
  -- still be applied
  case resolve_type types (type_check types decl f) of
    TFun t1 t2
      | equal_types types t1 (type_check types decl a) -> t2
      | otherwise -> error "type error"
    _ -> error "type error"
type_check types decl (ELet x t e b)
  -- the binding is visible in its own definition (recursive let), matching
  -- how 'eval' ties the knot
  | equal_types types t (type_check types new_decl e) =
      type_check types new_decl b
  | otherwise = error "type error "
  where
    new_decl = Table.bind decl x t
type_check types decl (ERecord fields) =
  TRec (foldl bind_field Table.empty fields)
  where
    bind_field tbl (f, e) = Table.bind tbl f (type_check types decl e)
type_check types decl (EMember e f) =
  -- resolve aliases so members of a named record type can be accessed
  case resolve_type types (type_check types decl e) of
    TRec fields -> Table.lookup fields f
    _ -> error "type error"
type_check types decl (ETypeVar name variants e) =
  type_check new_types new_decl e
  where
    ctors = foldl (\tbl (c, args) -> Table.bind tbl c args) Table.empty variants
    typ = TVar name ctors
    new_types = Table.bind types name typ
    -- each constructor c(a1, ..., an) gets the curried type
    -- a1 -> ... -> an -> name, added on top of what is already in scope
    new_decl =
      foldl (\d (c, args) -> Table.bind d c (foldr TFun typ args)) decl variants
type_check types decl (ETypeRec name fields e) =
  type_check new_types decl e
  where
    field_types = foldl (\tbl (f, t) -> Table.bind tbl f t) Table.empty fields
    new_types = Table.bind types name (TRec field_types)
type_check types decl (ETypeDef name t e) =
  type_check (Table.bind types name t) decl e
type_check types decl (ECase e pats) =
  case map check_pat pats of
    [] -> error "malformed case statement"
    (p : ps)
      | all (equal_types types p) ps -> p
      | otherwise -> error "type error"
  where
    typ = type_check types decl e
    -- whether a pattern fits the scrutinee's type, plus the declarations
    -- visible in that alternative's body
    check_guard ElsePat = (True, decl)
    check_guard (VarPat x) = (True, Table.bind decl x typ)
    check_guard (NumPat _) = (equal_types types typ TInt, decl)
    check_guard (CtorPat c xs) =
      let (arg_typ, ret_typ) = uncurry_type (Table.lookup decl c)
      in ( equal_types types typ ret_typ && length xs == length arg_typ
         , foldl (\d (x, t) -> Table.bind d x t) decl (zip xs arg_typ)
         )
    check_pat (g, b) =
      case check_guard g of
        (True, new_decl) -> type_check types new_decl b
        (False, _) -> error "type error"

-- | Evaluate an expression in an environment. Runtime errors abort with
-- 'error'.
eval :: Environment -> Expr -> Value
eval _ (ENum n) = VNum n
eval env (EFun arg _ body) = VFun arg env body
eval env (EId x) = lookup_variable env x
eval env (EPlus e1 e2) =
  case (eval env e1, eval env e2) of
    (VNum x1, VNum x2) -> VNum (x1 + x2)
    _ -> error "addition of non-numbers"
eval env (ETimes e1 e2) =
  case (eval env e1, eval env e2) of
    (VNum x1, VNum x2) -> VNum (x1 * x2)
    _ -> error "multiplication of non-numbers"
eval env (EApp f a) =
  let val = eval env a
  in case eval env f of
       VFun p e b -> eval (bind_variable e p val) b
       VCtor _ 0 _ -> error "constructor applied to too many arguments"
       -- constructors are curried: collect one more argument
       VCtor ctor ar args -> VCtor ctor (ar - 1) (args ++ [val])
       _ -> error "application to non-function"
eval env (ERecord fields) =
  -- build the same table type that 'VRec' holds (as 'type_check' does)
  VRec (foldl (\tbl (f, e) -> Table.bind tbl f (eval env e)) Table.empty fields)
eval env (EMember e f) =
  case eval env e of
    VRec fields -> Table.lookup fields f
    _ -> error "not a record"
eval env (ELet x _ e b) = eval new_env b
  where
    -- recursive let: the value is evaluated in the environment that
    -- already contains it
    new_env = bind_variable env x v
    v = eval new_env e
eval env (ETypeVar _ ctors e) = eval new_env e
  where
    new_env =
      foldl (\en (c, args) -> bind_variable en c (VCtor c (length args) [])) env ctors
eval env (ETypeRec _ _ e) = eval env e
eval env (ETypeDef _ _ e) = eval env e
eval env (ECase scrutinee pats) = iter pats
  where
    val = eval env scrutinee
    iter [] = error "no matching pattern in case expression"
    iter ((ElsePat, b) : _) = eval env b
    iter ((VarPat x, b) : rest) =
      -- check whether x is a constructor or a variable: a name bound to a
      -- nullary constructor tests for that constructor; any other name
      -- (unbound, or bound to some other value) binds the scrutinee,
      -- shadowing the previous binding
      case maybe_lookup_variable env x of
        Just (VCtor c 0 []) ->
          case val of
            VCtor d 0 [] | c == d -> eval env b
            _ -> iter rest
        _ -> eval (bind_variable env x val) b
    iter ((NumPat n, b) : rest) =
      case val of
        VNum k | n == k -> eval env b
        _ -> iter rest
    iter ((CtorPat c xs, b) : rest) =
      case val of
        VCtor d 0 args
          | c == d && length xs == length args ->
              eval (foldl (\en (x, a) -> bind_variable en x a) env (zip xs args)) b
        _ -> iter rest

-- | Translate the parser's surface syntax into core expressions.
desugar :: PExpr -> Expr
desugar (PId x) = EId x
desugar (PNum n) = ENum n
desugar (PPlus e1 e2) = EPlus (desugar e1) (desugar e2)
desugar (PTimes e1 e2) = ETimes (desugar e1) (desugar e2)
-- a - b is encoded as a + (-1 * b)
desugar (PMinus e1 e2) = EPlus (desugar e1) (ETimes (ENum (-1)) (desugar e2))
desugar (PEqual e1 e2) =
  ECase (desugar (PMinus e1 e2))
    [ (NumPat 0, EId "%True")
    , (ElsePat, EId "%False")
    ]
-- a parameterless function takes the unit value
desugar (PFun [] b) = EFun "()" type_unit (desugar b)
desugar (PFun [(p, Just t)] b) = EFun p (desugar_type t) (desugar b)
desugar (PFun ((p, Just t) : args) b) =
  EFun p (desugar_type t) (desugar (PFun args b))
desugar (PApp f []) = EApp (desugar f) (EId "()")
desugar (PApp f args) = foldl EApp (desugar f) (map desugar args)
desugar (PRecord fields) = ERecord [(f, desugar e) | (f, e) <- fields]
desugar (PMember e f) = EMember (desugar e) f
desugar (PDecl (PLet x (Just t) e) b) =
  ELet x (desugar_type t) (desugar e) (desugar b)
desugar (PDecl (PLetFun f args (Just t) b) e) =
  ELet f fun_type (desugar (PFun args b)) (desugar e)
  where
    ret_type = desugar_type t
    -- A parameterless function desugars to a function of unit (see the
    -- PFun [] case above), so its declared type must take unit as well.
    fun_type = case args of
      [] -> TFun type_unit ret_type
      _ -> foldr (TFun . desugar_type . arg_annotation) ret_type args
    arg_annotation (_, Just at) = at
    arg_annotation (_, Nothing) = error "unsupported syntactic construct"
desugar (PDecl (PTypeVar t [] ctors) e) =
  ETypeVar t [(c, map desugar_type args) | (c, args) <- ctors] (desugar e)
desugar (PDecl (PTypeRec t [] fields) e) =
  ETypeRec t [(f, desugar_type ft) | (f, ft) <- fields] (desugar e)
desugar (PDecl (PTypeDef t [] t2) e) = ETypeDef t (desugar_type t2) (desugar e)
desugar (PIf c t e) =
  ECase (desugar c)
    [ (VarPat "%True", desugar t)
    , (ElsePat, desugar e)
    ]
desugar (PCase e pats) = ECase (desugar e) [(desugar_pat p, desugar b) | (p, b) <- pats]
desugar (PListLit es t) = foldr cons tl es
  where
    cons x rest = EApp (EApp (EId "%Cons") (desugar x)) rest
    tl = maybe (EId "%Nil") desugar t
desugar _ = error "unsupported syntactic construct"

-- | Translate a surface pattern into a core pattern.
desugar_pat :: PPattern -> Pattern
desugar_pat PElsePat = ElsePat
desugar_pat (PVarPat x) = VarPat x
desugar_pat (PCtorPat c xs) = CtorPat c xs
desugar_pat (PNumPat n) = NumPat n

-- | Translate a surface type annotation into a 'Type'. Any name other than
-- @int@ becomes a 'TName' reference resolved later by the type checker.
desugar_type :: PExpr -> Type
desugar_type (PId "int") = TInt
desugar_type (PId n) = TName n
-- desugar_type (PApp t a) = case desugar_type t of
--   TVar c args -> TVar c (args ++ (map desugar_type a))
--   _ -> error "invalid type annotation"
desugar_type (PArrow a b) = TFun (desugar_type a) (desugar_type b)
desugar_type _ = error "invalid type annotation"

-- | Render a list as @l ++ "x, y, z" ++ r@; an empty list renders as the
-- empty string (without the delimiters).
show_list :: Show a => String -> String -> (a -> String) -> [a] -> String
show_list _ _ _ [] = ""
show_list l r sh xs = l ++ intercalate ", " (map sh xs) ++ r

instance Show Value where
  show (VNum n) = show n
  show (VCtor c _ args) = c ++ show_list "(" ")" show args
  show (VRec fields) =
    show_list "[" "]" (\(f, v) -> f ++ " = " ++ show v) (Table.to_list fields)
  show (VFun a _ b) = "fun (" ++ a ++ ") { " ++ show b ++ " }"

instance Show Type where
  show TInt = "int"
  show (TName n) = n
  show (TFun a b) = "(" ++ show a ++ " -> " ++ show b ++ ")"
  -- only the name: the constructor table is not a list, and for recursive
  -- types such as the builtin list it would never finish printing
  show (TVar c _) = c
  show (TRec fields) =
    show_list "[" "]" (\(f, t) -> f ++ " : " ++ show t) (Table.to_list fields)

instance Show Expr where
  show (EId x) = x
  show (ENum n) = show n
  show (EPlus e1 e2) = "(" ++ show e1 ++ " + " ++ show e2 ++ ")"
  show (ETimes e1 e2) = "(" ++ show e1 ++ " * " ++ show e2 ++ ")"
  show (EFun a t b) = "fun (" ++ a ++ " : " ++ show t ++ ") { " ++ show b ++ " }"
  show (EApp f a) = show fn ++ show_list "(" ")" show args
    where
      -- flatten nested applications f(a)(b) into f(a, b)
      (fn, args) = collect_args f [a]
      collect_args (EApp g b) acc = collect_args g (b : acc)
      collect_args g acc = (g, acc)
  show (ERecord fields) = show_list "[" "]" (\(f, e) -> f ++ " = " ++ show e) fields
  show (EMember e f) = show e ++ "." ++ f
  show (ELet x t b e) =
    "let " ++ x ++ " :" ++ show t ++ " = " ++ show b ++ "; " ++ show e
  show (ECase e pats) =
    ("case " ++ show e) ++ concat [" | " ++ show p ++ " => " ++ show b | (p, b) <- pats]
  show (ETypeVar t ctors e) =
    ("type " ++ t ++ " =")
      ++ concat [" | " ++ c ++ show_list "(" ")" show args | (c, args) <- ctors]
      ++ "; "
      ++ show e
  show (ETypeRec t fields e) =
    "type " ++ t ++ " = {"
      ++ show_list "(" ")" (\(c, ft) -> c ++ " : " ++ show ft) fields
      ++ "}; "
      ++ show e
  show (ETypeDef t t2 e) = "type " ++ t ++ " = " ++ show t2 ++ "; " ++ show e

instance Show Pattern where
  show ElsePat = "else"
  show (VarPat x) = x
  show (NumPat n) = show n
  show (CtorPat c xs) = c ++ show_list "(" ")" id xs

-- Builtin types. Variant types are compared by name, so the constructor
-- tables are informational only.
type_unit :: Type
type_unit = TVar "unit" Table.empty

type_bool :: Type
type_bool = TVar "bool" (Table.bind (Table.bind Table.empty "True" []) "False" [])

type_pair :: Type
type_pair = TVar "product" (Table.bind Table.empty "Pair" [TInt, TInt])

type_list :: Type
type_list =
  TVar "list" (Table.bind (Table.bind Table.empty "Nil" []) "Cons" [TInt, type_list])

builtin_types :: Declarations
builtin_types =
  foldl
    (\types (n, t) -> Table.bind types n t)
    Table.empty
    [ ("unit", type_unit)
    , ("bool", type_bool)
    , ("product", type_pair)
    , ("list", type_list)
    ]

-- Builtin constructors, bound both as runtime values and with their types.
-- The %-prefixed names are the ones the desugarer emits, so they keep
-- working even when a program shadows the plain names.
builtin_env :: Environment
builtin_decls :: Declarations
(builtin_env, builtin_decls) =
  foldl
    (\(env, decls) (ctor, val, typ) -> (bind_variable env ctor val, Table.bind decls ctor typ))
    (empty_env, Table.empty)
    [ ("()", VCtor "()" 0 [], type_unit)
    , ("True", VCtor "True" 0 [], type_bool)
    , ("False", VCtor "False" 0 [], type_bool)
    , ("Pair", VCtor "Pair" 2 [], TFun TInt (TFun TInt type_pair))
    , ("Nil", VCtor "Nil" 0 [], type_list)
    , ("Cons", VCtor "Cons" 2 [], TFun TInt (TFun type_list type_list))
    , ("%True", VCtor "True" 0 [], type_bool)
    , ("%False", VCtor "False" 0 [], type_bool)
    , ("%Nil", VCtor "Nil" 0 [], type_list)
    , ("%Cons", VCtor "Cons" 2 [], TFun TInt (TFun type_list type_list))
    ]

-- | Parse, desugar and evaluate a program.
run str = eval builtin_env (desugar (parse str))

-- | Parse, desugar and type-check a program.
run_tc str = type_check builtin_types builtin_decls (desugar (parse str))

main :: IO ()
main = repl (show . run) (Just (show . run_tc))