-- Interpreter for a small functional language: abstract syntax, a
-- (work-in-progress) type unifier, an evaluator, a desugarer from the parser's
-- surface syntax, pretty-printers, the builtin environment and the REPL entry
-- point.

import Debug.Trace
import Control.Monad
import Control.Monad.Trans.State.Strict
import qualified Table
import EnvSimple hiding (T)
import qualified EnvSimple(T)
import Parser
import Repl

-- | Identifiers: variables, constructors, record fields and type parameters.
type Symbol = String

-- | Run-time values produced by 'eval'.
data Value = VNum Integer
           | VCtor Symbol Int [Value]        -- ^ constructor name, number of arguments still expected, arguments collected so far
           | VRec (Table.T Symbol Value)     -- ^ record: field name to value
           | VFun Symbol Environment Expr    -- ^ closure: parameter, captured environment, body

-- | Types manipulated by the unifier.
data Type = TParam Symbol                    -- ^ named type parameter / unification variable
          | TInt
          | TBool
          | TFun Type Type
          | TRec (Table.T Symbol Type)       -- ^ record type: field name to field type
          | TVar (Table.T Symbol [Type])     -- ^ variant type: constructor name to argument types
          | TMu Type                         -- ^ recursive type binder (see 'unify_rec')
          | TArg Int                         -- ^ index into the stack of enclosing 'TMu' bodies
          | TApp Type Type
          deriving Eq

-- | Core expressions (the output of 'desugar').
data Expr = EId Symbol
          | ENum Integer
          | EPlus Expr Expr
          | ETimes Expr Expr
          | EFun Symbol (Maybe Type) Expr
          | EApp Expr Expr
          | ELet Symbol (Maybe Type) Expr Expr
          | ERecord [(Symbol, Expr)]
          | EMember Expr Symbol
          | ETypeVar Symbol [(Symbol, [Type])] Expr
          | ETypeRec Symbol [(Symbol, Type)] Expr
          | ETypeDef Symbol Type Expr
          | ECase Expr [(Pattern, Expr)]
          deriving Eq

-- | Patterns usable in the alternatives of an 'ECase'.
data Pattern = ElsePat
             | VarPat Symbol
             | CtorPat Symbol [Symbol]
             | NumPat Integer
             deriving Eq

type Environment = EnvSimple.T Value
type Declarations = Table.T Symbol (Type, Bool)

-- | Split a curried function type into its argument types and final result
-- type.  A non-function type yields an empty argument list.
uncurry_type :: Type -> ([Type], Type)
uncurry_type (TFun a b) =
    let (arg_types, ret_type) = uncurry_type b
    in (a:arg_types, ret_type)
uncurry_type a = ([], a)

-- | Collect the names of all 'TParam's occurring in a type (duplicates are
-- kept).
collect_params :: Type -> [Symbol]
collect_params (TParam x) = [x]
collect_params TInt = []
collect_params TBool = []
collect_params (TFun a b) = collect_params a ++ collect_params b
collect_params (TRec fields) = Table.fold (\_ a acc -> collect_params a ++ acc) [] fields
collect_params (TVar vars) =
    Table.fold (\_ args acc -> foldl (\acc a -> collect_params a ++ acc) acc args) [] vars
collect_params (TMu a) = collect_params a
collect_params (TArg n) = []
collect_params (TApp a b) = collect_params a ++ collect_params b

-- | Append a suffix to the name of every 'TParam' in a type, leaving the
-- structure of the type untouched.
rename_params :: String -> Type -> Type
rename_params suf (TParam x) = TParam (x ++ suf)
rename_params suf TInt = TInt
rename_params suf TBool = TBool
rename_params suf (TFun a b) = TFun (rename_params suf a) (rename_params suf b)
rename_params suf (TRec fields) = TRec (Table.map (\_ e -> rename_params suf e) fields)
rename_params suf (TVar vars) = TVar (Table.map (\_ args -> map (rename_params suf) args) vars)
-- Keep the TMu binder: dropping it would change the shape of the type and
-- leave any TArg inside without its binder.
rename_params suf (TMu a) = TMu (rename_params suf a)
rename_params suf (TArg n) = TArg n
rename_params suf (TApp a b) = TApp (rename_params suf a) (rename_params suf b)

-- | Apply 'rename_params' to every type that occurs inside an expression
-- (annotations and type declarations), recursing through all sub-expressions.
rename_params_in_expr :: String -> Expr -> Expr
rename_params_in_expr suf (EId x) = EId x
rename_params_in_expr suf (ENum n) = ENum n
rename_params_in_expr suf (EPlus x y) =
    EPlus (rename_params_in_expr suf x) (rename_params_in_expr suf y)
rename_params_in_expr suf (ETimes x y) =
    ETimes (rename_params_in_expr suf x) (rename_params_in_expr suf y)
rename_params_in_expr suf (EFun x t b) =
    EFun x (fmap (rename_params suf) t) (rename_params_in_expr suf b)
rename_params_in_expr suf (EApp f x) =
    EApp (rename_params_in_expr suf f) (rename_params_in_expr suf x)
rename_params_in_expr suf (ELet x t b e) =
    ELet x (fmap (rename_params suf) t) (rename_params_in_expr suf b) (rename_params_in_expr suf e)
-- ERecord carries a plain association list, so map over the pairs directly.
rename_params_in_expr suf (ERecord fields) =
    ERecord (map (\(f, e) -> (f, rename_params_in_expr suf e)) fields)
rename_params_in_expr suf (EMember e f) = EMember (rename_params_in_expr suf e) f
-- Constructor argument types and record field types are renamed too, in line
-- with what is done for ETypeDef below.
rename_params_in_expr suf (ETypeVar n vars e) =
    ETypeVar n
             (map (\(c, args) -> (c, map (rename_params suf) args)) vars)
             (rename_params_in_expr suf e)
rename_params_in_expr suf (ETypeRec n fields e) =
    ETypeRec n
             (map (\(f, t) -> (f, rename_params suf t)) fields)
             (rename_params_in_expr suf e)
rename_params_in_expr suf (ETypeDef n t e) =
    ETypeDef n (rename_params suf t) (rename_params_in_expr suf e)
rename_params_in_expr suf (ECase e pats) =
    ECase (rename_params_in_expr suf e)
          (map (\(p,e) -> (p, rename_params_in_expr suf e)) pats)

{-
equal_types types t1 t2 = case (t1, t2) of
    (TName n, _) -> equal_types types (Table.lookup types n) t2
    (_, TName n) -> equal_types types t1 (Table.lookup types n)
    (TInt, TInt) -> True
    (TFun a1 b1, TFun a2 b2) -> equal_types types a1 a2 && equal_types types b1 b2
    (TVar n1 _, TVar n2 _) -> n1 == n2
    (TRec f1, TRec f2) -> True -- XXX
-}

-- | Counter for fresh names paired with the current type-variable bindings.
type VarTable = (Int, Table.T Symbol Type)

-- | Unifier state: variable table, stack of 'TMu' bodies entered on the left
-- side, the same for the right side, and the pairs of types already visited.
data UnifyData = UD VarTable [Type] [Type] [(Type,Type)]

-- | Look up the current binding of a type variable, if any.
lookup_tvar :: Symbol -> State UnifyData (Maybe Type)
lookup_tvar x = do
    UD (_, tab) _ _ _ <- get
    return (Table.maybe_lookup tab x)

-- | Bind a type variable to a type (emits a debug trace).
bind_tvar :: Symbol -> Type -> State UnifyData ()
bind_tvar x t = do
    UD (n, tab) lm rm vis <- get
    trace ("binding " ++ x ++ " : " ++ show t) $
        put (UD (n, Table.bind tab x t) lm rm vis)

-- | Report whether the pair @(a, b)@ has been visited before; if it has not,
-- record it as visited.
was_visited :: Type -> Type -> State UnifyData Bool
was_visited a b = do
    UD var lm rm vis <- get
    if elem (a, b) vis
        then return True
        else do
            put (UD var lm rm ((a,b):vis))
            return False

-- | Fetch the n-th entry (most recently pushed first) of the left 'TMu' stack.
-- NOTE(review): uses (!!), so an index beyond the stack depth is a run-time error.
lookup_mu_left :: Int -> State UnifyData Type
lookup_mu_left n = do
    UD _ lm _ _ <- get
    return (lm !! n)

-- | Fetch the n-th entry (most recently pushed first) of the right 'TMu' stack.
-- NOTE(review): uses (!!), so an index beyond the stack depth is a run-time error.
lookup_mu_right :: Int -> State UnifyData Type
lookup_mu_right n = do
    UD _ _ rm _ <- get
    return (rm !! n)

-- | Push a 'TMu' body onto the left stack.
push_mu_left :: Type -> State UnifyData ()
push_mu_left t = do
    UD var lm rm vis <- get
    put (UD var (t:lm) rm vis)

-- | Push a 'TMu' body onto the right stack.
push_mu_right :: Type -> State UnifyData ()
push_mu_right t = do
    UD var lm rm vis <- get
    put (UD var lm (t:rm) vis)

-- | Unify two types, returning whether they match.  A pair that has already
-- been visited is assumed to match, which cuts off repeated work on the same
-- pair.
unify :: Type -> Type -> State UnifyData Bool
unify t1 t2 = trace ("unifying " ++ show t1 ++ " and " ++ show t2) $ do
    seen <- was_visited t1 t2
    if seen then return True else unify_rec t1 t2

-- | Run the second unification step only if the first one succeeded.
unify_and_then :: State UnifyData Bool -> State UnifyData Bool -> State UnifyData Bool
unify_and_then first second = do
    ok <- first
    if ok then second else return False

-- | Structural unification; called by 'unify' once the visited check passed.
unify_rec :: Type -> Type -> State UnifyData Bool
unify_rec (TParam x1) (TParam x2) = do
    v1 <- lookup_tvar x1
    v2 <- lookup_tvar x2
    case (v1, v2) of
        (Nothing, Nothing) -> do
            -- never bind a variable to itself
            unless (x1 == x2) (bind_tvar x1 (TParam x2))
            return True
        (Nothing, Just t) -> do { bind_tvar x1 t; return True }
        (Just t, Nothing) -> do { bind_tvar x2 t; return True }
        (Just s, Just t) -> unify s t
unify_rec (TParam x) t = do
    v <- lookup_tvar x
    case v of
        Nothing -> do { bind_tvar x t; return True }
        Just s -> unify s t
unify_rec t (TParam x) = do
    v <- lookup_tvar x
    case v of
        Nothing -> do { bind_tvar x t; return True }
        Just s -> unify t s
unify_rec TInt TInt = return True
unify_rec TBool TBool = return True
unify_rec (TFun a1 b1) (TFun a2 b2) = do
    m1 <- unify a1 a2
    m2 <- unify b1 b2
    return (m1 && m2)
unify_rec (TRec fields1) (TRec fields2) =
    Table.fold2 (\_ a b acc -> acc `unify_and_then` unify a b)
                (return True) fields1 fields2
unify_rec (TVar vars1) (TVar vars2) =
    Table.fold2 (\_ args1 args2 acc -> acc `unify_and_then` unify_args args1 args2)
                (return True) vars1 vars2
  where
    -- constructors must agree on arity before their arguments are compared
    unify_args args1 args2
        | length args1 /= length args2 = return False
        | otherwise =
            foldl (\acc (a, b) -> acc `unify_and_then` unify a b)
                  (return True) (zip args1 args2)
unify_rec (TMu a) (TMu b) = do
    push_mu_left a
    push_mu_right b
    unify_rec a b
unify_rec (TMu a) t = do
    push_mu_left a
    unify_rec a t
unify_rec t (TMu a) = do
    push_mu_right a
    unify_rec t a
unify_rec t (TArg n) = do
    a <- lookup_mu_right n
    unify t a
unify_rec (TArg n) t = do
    a <- lookup_mu_left n
    unify a t
unify_rec (TApp a1 b1) t = return False
unify_rec _ _ = return False

-- | Return the current counter value and increment it.
alloc_var :: State VarTable Int
alloc_var = do
    (n, ct) <- get
    put (n+1, ct)
    return n

-- | Produce a fresh variable name of the form @"t <n>"@.
new_var_name :: State VarTable String
new_var_name = do
    n <- alloc_var
    return ("t " ++ show n)

{-
type_check :: Expr -> Type
type_check expr = evalState (check Table.empty expr >>= unpack_type) (0, Table.empty)
  where
    unpack_type (TVar x) = do
        typ <- lookup_tvar x
        case typ of
            Just t -> unpack_type t
            Nothing -> return (TVar x)
    unpack_type TInt = return TInt
    unpack_type TBool = return TBool
    unpack_type (TFun a b) = do
        a2 <- unpack_type a
        b2 <- unpack_type b
        return (TFun a2 b2)
    unpack_type (TProduct args) = do
        args2 <- mapM unpack_type args
        return (TProduct args2)
    unpack_type (TList a) = do
        a2 <- unpack_type a
        return (TList a2)

check decl expr = do
    (_,tab) <- get
    trace ("type checking " ++ show decl ++ " " ++ show expr ++ " " ++ show tab) $
        tcheck decl expr >>= \typ ->
            (get >>= \(_,tab) ->
                trace ("result: " ++ show decl ++ " " ++ show expr ++ " : " ++ show typ ++ " " ++ show tab) $
                    return typ)

check :: Declarations -> Expr -> State VarTable Type

tcheck decl (ECtor (CNum n) []) = return TInt
tcheck decl (ECtor CTrue []) = return TBool
tcheck decl (ECtor CFalse []) = return TBool
tcheck decl (ECtor CTuple es) = do
    ts <- mapM (check decl) es
    return (TProduct ts)
tcheck decl (ECtor CNil []) = do
    x <- new_var_name
    return (TList (TVar x))
tcheck decl (ECtor CCons [h,t]) = do
    ht <- check decl h
    tt <- check decl t
    case tt of
        TList typ -> do
            match <- unify ht typ
            unless match (error "type error")
            return tt
        _ -> error "type error"
tcheck decl (ECtor _ _) = error "type error"
tcheck decl (EId id) = do
    let (t, generalise) = Table.lookup decl id
    if generalise
        then do
            n <- alloc_var
            return (rename_params (" " ++ show n) t)
        else return t
tcheck decl (EFun x t body) = do
    bt <- check (Table.bind decl x (t, False)) body
    return (TFun t bt)
tcheck decl (EPlus e1 e2) = do
    t1 <- check decl e1
    t2 <- check decl e2
    m1 <- unify t1 TInt
    m2 <- unify t2 TInt
    unless (m1 && m2) (error "type error")
    return TInt
tcheck decl (ETimes e1 e2) = do
    t1 <- check decl e1
    t2 <- check decl e2
    m1 <- unify t1 TInt
    m2 <- unify t2 TInt
    unless (m1 && m2) (error "type error")
    return TInt
tcheck decl (EApp f a) = do
    at <- check decl a
    ft <- check decl f
    case ft of
        TFun t1 t2 -> do
            match <- unify t1 at
            unless match (error "type error")
            return t2
        _ -> error "type error"
tcheck decl (ELet x t e b) = do
    et <- check (Table.bind decl x (t, False)) e
    match <- unify t et
    unless match (error "type error")
    bt <- check (Table.bind decl x (t, True)) b
    return bt
tcheck decl (EIf c1 c2 t e) = do
    t1 <- check decl c1
    t2 <- check decl c2
    tt <- check decl t
    et <- check decl e
    m1 <- unify t1 t2
    m2 <- unify tt et
    unless (m1 && m2) (error "type error")
    return tt
tcheck decl (ECase e pats) = do
    typ <- check decl e
    ps <- mapM (check_pat typ) pats
    match <- unify_all ps
    unless match (error "type error")
    return (head ps)
  where
    unify_all (p1:p2:ps) = do
        m1 <- unify p1 p2
        m2 <- unify_all (p2:ps)
        return (m1 && m2)
    unify_all [p] = return True
    unify_all [] = return True
    check_pat typ (g, b) = do
        (m, new_decl) <- check_guard typ g
        unless m (error "type error")
        check new_decl b
    check_guard typ ElsePat = return (True, decl)
    check_guard typ (VarPat x t) = do
        m <- unify typ t
        return (m, Table.bind decl x (t, False))
    check_guard typ (CtorPat c xs) = do
        ct <- ctor_type c (map snd xs)
        m <- unify typ ct
        return (m, foldl (\decl (x,t) -> Table.bind decl x (t, False)) decl xs)
    ctor_type (CNum _) [] = return (TInt)
    ctor_type CTrue [] = return (TBool)
    ctor_type CFalse [] = return (TBool)
    ctor_type CTuple args = return (TProduct args)
    ctor_type CCons [h,t] = do
        m <- unify t (TList h)
        unless m (error "type error")
        return t
    ctor_type CNil [] = return (TList TInt)
-}

-- | Evaluate a core expression in an environment.  Ill-formed programs
-- (arithmetic on non-numbers, applying a non-function, a case with no matching
-- alternative, ...) abort with 'error'.
eval :: Environment -> Expr -> Value
eval env (ENum n) = VNum n
eval env (EFun arg _ body) = VFun arg env body
eval env (EId name) = lookup_variable env name
eval env (EPlus e1 e2) =
    case (eval env e1, eval env e2) of
        (VNum x1, VNum x2) -> VNum (x1 + x2)
        _ -> error "addition of non-numbers"
eval env (ETimes e1 e2) =
    case (eval env e1, eval env e2) of
        (VNum x1, VNum x2) -> VNum (x1 * x2)
        _ -> error "multiplication of non-numbers"
eval env (EApp f a) =
    let val = eval env a
    in case eval env f of
        VFun p e b -> eval (bind_variable e p val) b
        VCtor ctor 0 args -> error "constructor applied to too many arguments"
        -- partially applied constructor: collect one more argument
        VCtor ctor ar args -> VCtor ctor (ar-1) (args ++ [val])
        _ -> error "application to non-function"
-- NOTE(review): building VRec with a list 'map' presumes Table.T is an
-- association list -- confirm against the Table module.
eval env (ERecord fields) = VRec (map (\(f,e) -> (f, eval env e)) fields)
eval env (EMember e f) =
    case eval env e of
        VRec fields -> Table.lookup fields f
        _ -> error "not a record"
eval env (ELet x _ e b) = eval new_env b
  where
    -- the bound expression is evaluated in the environment that already
    -- contains x, so definitions may refer to themselves
    new_env = bind_variable env x v
    v = eval new_env e
eval env (ETypeVar _ ctors e) = eval new_env e
  where
    -- every constructor becomes a value waiting for (length args) arguments
    new_env = foldl (\env (c,args) -> bind_variable env c (VCtor c (length args) [])) env ctors
eval env (ETypeRec _ _ e) = eval env e
eval env (ETypeDef _ _ e) = eval env e
eval env (ECase e pats) = iter pats
  where
    val = eval env e
    iter [] = error "no matching pattern in case expression"
    iter ((ElsePat, e):ps) = eval env e
    iter ((VarPat x, e):ps) =
        -- check whether x is a constructor or a variable
        case maybe_lookup_variable env x of
            Just (VCtor c 0 []) ->
                case val of
                    VCtor d 0 [] | c == d -> eval env e
                    _ -> iter ps
            -- Unbound, or bound to something that is not a nullary
            -- constructor: x is a pattern variable that (re)binds the
            -- scrutinee.  Previously the latter case had no alternative and
            -- crashed with a pattern-match failure.
            _ -> eval (bind_variable env x val) e
    iter ((NumPat n, e):ps) =
        case val of
            VNum k | n == k -> eval env e
            _ -> iter ps
    iter ((CtorPat c xs,e):ps) =
        case val of
            VCtor d 0 args | c == d && length xs == length args ->
                eval (foldl (\env (x,a) -> bind_variable env x a) env (zip xs args)) e
            _ -> iter ps

-- | Translate the parser's surface syntax into core expressions.
desugar :: PExpr -> Expr
desugar (PId name) = EId name
desugar (PNum n) = ENum n
desugar (PPlus e1 e2) = EPlus (desugar e1) (desugar e2)
desugar (PTimes e1 e2) = ETimes (desugar e1) (desugar e2)
-- a - b  ==>  a + (-1) * b
desugar (PMinus e1 e2) = EPlus (desugar e1) (ETimes (ENum (-1)) (desugar e2))
-- a == b  ==>  case a - b | 0 => True | else => False
desugar (PEqual e1 e2) =
    ECase (desugar (PMinus e1 e2))
          [ (NumPat 0, EId "%True"),
            (ElsePat, EId "%False") ]
desugar (PFun [] b) = EFun "()" Nothing (desugar b)
desugar (PFun [(p,t)] b) = EFun p (fmap desugar_type t) (desugar b)
desugar (PFun ((p,t):args) b) = EFun p (fmap desugar_type t) (desugar (PFun args b))
desugar (PApp f []) = EApp (desugar f) (EId "()")
desugar (PApp f [a]) = EApp (desugar f) (desugar a)
desugar (PApp f args) = foldl (\f a -> EApp f a ) (desugar f) (map desugar args)
desugar (PRecord fields) = ERecord (map (\(f,e) -> (f, desugar e)) fields)
desugar (PMember e f) = EMember (desugar e) f
desugar (PDecl (PLet x t e) b) = ELet x (fmap desugar_type t) (desugar e) (desugar b)
desugar (PDecl (PLetFun f args t b) e) =
    ELet f (fmap desugar_type fun_type) (desugar (PFun args b)) (desugar e)
  where
    -- The function's type annotation exists only when the return type and
    -- every argument are annotated; otherwise the binding is left without
    -- one (previously an unannotated argument crashed here).
    fun_type = do
        ret <- t
        arg_types <- mapM snd args
        return (foldr (\a b -> (PArrow a b)) ret arg_types)
desugar (PDecl (PTypeVar t [] ctors) e) =
    ETypeVar t (map (\(c,args) -> (c, map desugar_type args)) ctors) (desugar e)
desugar (PDecl (PTypeRec t [] fields) e) =
    ETypeRec t (map (\(f,t) -> (f, desugar_type t)) fields) (desugar e)
desugar (PDecl (PTypeDef t [] t2) e) = ETypeDef t (desugar_type t2) (desugar e)
desugar (PIf c t e) =
    ECase (desugar c)
          [ (VarPat "%True", desugar t),
            (ElsePat, desugar e) ]
-- list literal with optional tail expression  ==>  nested %Cons applications
desugar (PCase e pats) = ECase (desugar e) (map (\(p,e) -> (desugar_pat p, desugar e)) pats)
desugar (PListLit es t) =
    foldr (\x y -> (EApp (EApp (EId "%Cons") (desugar x)) y))
          (case t of
              Nothing -> EId "%Nil"
              Just e -> desugar e)
          es
desugar _ = error "unsupported syntactic construct"

-- | Translate a surface pattern into a core pattern.
desugar_pat :: PPattern -> Pattern
desugar_pat PElsePat = ElsePat
desugar_pat (PVarPat x) = VarPat x
desugar_pat (PCtorPat c xs) = CtorPat c xs
desugar_pat (PNumPat n) = NumPat n

-- | Interpret a surface expression as a type annotation.  Only @int@, other
-- identifiers (read as type parameters) and arrows are accepted.
desugar_type :: PExpr -> Type
desugar_type (PId "int") = TInt
desugar_type (PId n) = TParam n
-- desugar_type (PApp t a) = case desugar_type t of
--     TVar c args -> TVar c (args ++ (map desugar_type a))
--     _ -> error "invalid type annotation"
desugar_type (PArrow a b) = TFun (desugar_type a) (desugar_type b)
desugar_type _ = error "invalid type annotation"

-- | Render a list as @l ++ "x, y, z" ++ r@ using the given element printer.
-- An empty list renders as the empty string, without the delimiters.
show_list :: Show a => String -> String -> (a -> String) -> [a] -> String
show_list l r sh [] = ""
show_list l r sh xs = l ++ foldl1 (\ x y -> x ++ ", " ++ y) (map sh xs) ++ r

instance Show Value where
    show (VNum n) = show n
    show (VCtor c _ args) = c ++ show_list "(" ")" show args
    show (VRec fields) = show_list "[" "]" (\(f,v) -> f ++ " = " ++ show v) (Table.to_list fields)
    show (VFun a _ b) = "fun (" ++ a ++ ") { " ++ show b ++ " }"

-- Every constructor of Type is covered: the unifier traces types with 'show',
-- so a missing case would abort unification at run time.
instance Show Type where
    show TInt = "int"
    show TBool = "bool"
    show (TParam n) = n
    show (TFun a b) = "(" ++ show a ++ " -> " ++ show b ++ ")"
    show (TVar ctors) =
        Table.fold (\ c args str -> str ++ "| " ++ c ++ show_list "(" ")" show args) "" ctors
    show (TRec fields) = show_list "[" "]" (\(f,t) -> f ++ " : " ++ show t) (Table.to_list fields)
    show (TMu a) = "mu(" ++ show a ++ ")"
    show (TArg n) = "#" ++ show n
    show (TApp a b) = show a ++ "(" ++ show b ++ ")"

instance Show Expr where
    show (EId x) = x
    show (ENum n) = show n
    show (EPlus e1 e2) = "(" ++ show e1 ++ " + " ++ show e2 ++ ")"
    show (ETimes e1 e2) = "(" ++ show e1 ++ " * " ++ show e2 ++ ")"
    show (EFun a t b) = "fun (" ++ a ++ " : " ++ show t ++ ") { " ++ show b ++ " }"
    show (EApp f a) = show fn ++ show_list "(" ")" show args
      where
        -- flatten nested applications so that f(a)(b) prints as f(a, b)
        (fn, args) = collect_args f [a]
        collect_args (EApp g b) args = collect_args g (b:args)
        collect_args g args = (g, args)
    show (ERecord fields) = show_list "[" "]" (\(f,e) -> f ++ " = " ++ show e) fields
    show (EMember e f) = show e ++ "." ++ f
    show (ELet x t b e) = "let " ++ x ++ " :" ++ show t ++ " = " ++ show b ++ "; " ++ show e
    show (ECase e pats) =
        foldl (++) ("case " ++ show e)
              (map (\(p,e) -> " | " ++ show p ++ " => " ++ show e) pats)
    show (ETypeVar t ctors e) =
        foldl (++) ("type " ++ t ++ " =")
              (map (\(c,args) -> " | " ++ c ++ show_list "(" ")" show args) ctors)
        ++ "; " ++ show e
    show (ETypeRec t fields e) =
        "type " ++ t ++ " = {" ++ show_list "(" ")" (\(c,t) -> c ++ " : " ++ show t) fields
        ++ "}; " ++ show e
    show (ETypeDef t t2 e) = "type " ++ t ++ " = " ++ show t2 ++ "; " ++ show e

instance Show Pattern where
    show ElsePat = "else"
    show (VarPat x) = x
    show (NumPat n) = show n
    show (CtorPat c xs) = c ++ show_list "(" ")" show xs

-- Types of the builtin constructors.
type_unit = TVar (Table.bind Table.empty "()" [])
type_bool = TVar (Table.bind (Table.bind Table.empty "True" []) "False" [])
-- The constructor is spelled "Pair" to match the builtin value constructor
-- registered below (it used to be "pair").
type_pair = TVar (Table.bind Table.empty "Pair" [TInt, TInt])
-- NOTE(review): type_list refers to itself, so it is an infinite value;
-- showing it or comparing it with (==) will not terminate.  TMu/TArg look
-- like the intended encoding for recursive types -- confirm before relying
-- on this definition in the unifier.
type_list = TVar (Table.bind (Table.bind Table.empty "Nil" []) "Cons" [TInt, type_list])

-- | Named builtin types.
builtin_types =
    foldl (\types (n,t) -> Table.bind types n t) Table.empty
          [("unit", type_unit),
           ("bool", type_bool),
           ("product", type_pair),
           ("list", type_list)]

-- | Initial evaluation environment and the types of the builtin constructors.
-- The @%@-prefixed names are the ones generated by 'desugar'.
(builtin_env, builtin_decls) =
    foldl (\(env, decls) (ctor, val, typ) -> (bind_variable env ctor val, Table.bind decls ctor typ))
          (empty_env, Table.empty)
          [("()", VCtor "()" 0 [], type_unit),
           ("True", VCtor "True" 0 [], type_bool),
           ("False", VCtor "False" 0 [], type_bool),
           ("Pair", VCtor "Pair" 2 [], TFun TInt (TFun TInt type_pair)),
           ("Nil", VCtor "Nil" 0 [], type_list),
           ("Cons", VCtor "Cons" 2 [], TFun TInt (TFun type_list type_list)),
           ("%True", VCtor "True" 0 [], type_bool),
           ("%False", VCtor "False" 0 [], type_bool),
           ("%Nil", VCtor "Nil" 0 [], type_list),
           ("%Cons", VCtor "Cons" 2 [], TFun TInt (TFun type_list type_list))]

-- | Parse, desugar and evaluate a program in the builtin environment.
run str = eval builtin_env (desugar (parse str))

-- | Parse, desugar and type-check a program.
-- NOTE(review): type_check is only present in this file as commented-out
-- code (and with a different arity); it must be supplied before this
-- compiles.
run_tc str = type_check builtin_types builtin_decls (desugar (parse str))

main :: IO ()
main = repl (show . run) (Just (show . run_tc))