{-
    Kaya - My favourite toy language.
    Copyright (C) 2004, 2005 Edwin Brady

    This file is distributed under the terms of the GNU General
    Public Licence. See COPYING for licence.
-}

module InfGadgets where

-- Helper functions for type inference

import Language

-- Substitutions and unification (from SPJ87)

-- | A substitution: a function from type-variable names to types.
-- Within this module, a variable that a substitution does not constrain is
-- mapped back to its own 'TyVar' (this is what 'id_subst' and 'delta' do).
type Subst = Name -> Type

-- | Apply a substitution to a type.
--
-- Every 'TyVar' is replaced by its image under the substitution. Function,
-- array and user-defined types are rebuilt around their substituted
-- components. Any other type (e.g. primitives) is returned unchanged.
subst :: Subst -> Type -> Type
subst s ty = case ty of
    TyVar n    -> s n
    Fn ns ts t -> Fn ns (map go ts) (go t)
    Array t    -> Array (go t)
    User n ts  -> User n (map go ts)
    other      -> other
  where
    go = subst s

-- | Compose two substitutions: @scomp s2 s1@ first looks a variable up in
-- @s1@, then applies @s2@ to the resulting type.
scomp :: Subst -> Subst -> Subst
scomp s2 s1 = subst s2 . s1

-- | The identity substitution: every type variable maps to itself.
id_subst :: Subst
id_subst = TyVar

-- | Single-variable substitution: maps @tn@ to @t@ and every other
-- variable to itself.
delta :: Name -> Type -> Subst
delta tn t = \other -> if tn == other then t else TyVar other

-- | Extend substitution @phi@ with the binding @tvn := t@.
--
-- Binding a variable to itself leaves @phi@ unchanged. If @tvn@ is among
-- the variables 'getVars' reports for @t@, the binding would describe an
-- infinite type, so this fails with a "file:line:" prefixed message.
-- Otherwise the new binding is composed after @phi@.
extend :: Monad m => String -> Int -> Subst -> Name -> Type -> m Subst
extend file line phi tvn t = case t of
    TyVar n | tvn == n -> return phi
    _ | tvn `elem` getVars t ->
          fail $ file++":"++show line++":Unification error - possible infinite type"
      | otherwise ->
          -- The $!s force the delta substitution and phi to WHNF once the
          -- extended substitution is demanded.
          return $ (scomp $! delta tvn t) $! phi

-- | Unify two types under the current substitution @phi@, returning the
-- extended substitution.
--
-- The argument is (type, type, source file, source line); the file and line
-- only prefix error messages. This wrapper just delegates to 'unify'' and is
-- a convenient place to hook in tracing.
unify :: Monad m => Subst -> (Type, Type, String, Int) -> m Subst
unify phi problem = unify' phi problem
-- | Worker for 'unify'.
--
-- * A type variable on the left is resolved through @phi@. If it is still
--   unbound, it is bound via 'extend'; otherwise the resolved types are
--   unified.
-- * Arrays unify element-wise.
-- * Function types unify result and argument types pairwise. The names
--   carried by 'Fn' are ignored.
-- * User types must have the same name and the same number of parameters.
-- * A type variable on the right is handled by swapping the sides.
-- * Anything else must be equal.
unify' phi (t1, t2, f, l) = case (t1, t2) of
    (TyVar tvn, t) ->
        let phitvn = phi tvn
            phit   = subst phi t
        in if phitvn == TyVar tvn
              then extend f l phi tvn phit
              else unify phi (phitvn, phit, f, l)
    (Array a, Array b) -> unify phi (a, b, f, l)
    (Fn _ ts t, Fn _ ts' t') ->
        zipfl (t : ts) (t' : ts') f l mismatch >>= unifyl phi
    (User n ts, User n' ts')
        | n == n' && length ts == length ts' ->
            zipfl ts ts' f l mismatch >>= unifyl phi
        | otherwise -> fail mismatch
    (t, TyVar tvn) -> unify phi (TyVar tvn, t, f, l)
    _ | t1 == t2  -> return phi
      | otherwise -> fail mismatch
  where
    -- Shared error for every structural mismatch.
    mismatch = f ++ ":" ++ show l ++ ":Can't unify " ++ show t1 ++ " and " ++ show t2

-- | Zip two lists into 4-tuples, attaching the constant values @z@ and @w@
-- to every pair. Fails with message @err@ if the lists differ in length.
zipfl :: Monad m => [a] -> [b] -> c -> d -> String -> m [(a,b,c,d)]
zipfl xs ys z w err = case (xs, ys) of
    ([], []) -> return []
    (x : xs', y : ys') ->
        zipfl xs' ys' z w err >>= \rest -> return ((x, y, z, w) : rest)
    _ -> fail err

-- | Unify each problem in turn, left to right. Each step's resulting
-- substitution is passed on to the next; the first failure aborts.
unifyl :: Monad m => Subst -> [(Type,Type,String,Int)] -> m Subst
unifyl phi problems = case problems of
    [] -> return phi
    p : rest -> unify phi p >>= \phi' -> unifyl phi' rest

-- | 0-based position of the LAST entry of @xs@ whose key equals @n@, or -1
-- if there is none.
--
-- Entries have the shape (key, (a, b)); only the key is compared. 'pToV'
-- appends inner bindings to the end of its context, so returning the last
-- match lets inner bindings shadow outer ones.
getpos n xs = getpos' n 0 (-1) xs

-- | Worker for 'getpos'. @i@ is the index of the current entry and @found@
-- is the index of the most recent match seen so far.
getpos' n _ found [] = found
getpos' n i found ((key, (_, _)) : rest) =
    if n == key
       then getpos' n (i + 1) i rest
       else getpos' n (i + 1) found rest

-- | Convert the global names (the Ps) to local variable indexes (Vs).
-- (The name is a reference to McKinna-Pollack '91. Apologies...)
--
-- @cs@ is the context of locals in scope. A 'Global' whose name occurs in
-- @cs@ becomes @Loc i@, where @i@ is the position 'getpos' reports (the last
-- matching entry). 'Bind' and 'Declare' append their variable, as
-- @(name, (type, [Public]))@, to the context used for their body. Every
-- other form is rebuilt with its sub-expressions converted.
pToV :: Context -> Expr Name -> Expr Name
pToV cs expr = case expr of
    Global n m ar ->
        -- Look the name up once and reuse the position.
        let pos = getpos n cs
        in if pos >= 0 then Loc pos else Global n m ar
    Loc l -> Loc l
    GVar x -> GVar x
    GConst c -> GConst c
    Lambda iv ns sc -> Lambda iv ns (conv sc)
    Closure ns rt sc -> Closure ns rt (conv sc)
    Bind n t v sc -> Bind n t (conv v) (pToV (extendCtx n t) sc)
    Declare f l (n, loc) t sc ->
        Declare f l (n, loc) t (pToV (extendCtx n t) sc)
    Return r -> Return (conv r)
    Assign l e -> Assign (lval l) (conv e)
    AssignOp op l e -> AssignOp op (lval l) (conv e)
    Seq a b -> Seq (conv a) (conv b)
    Apply f args -> Apply (conv f) (fmap conv args)
    Partial f args i -> Partial (conv f) (fmap conv args) i
    Foreign ty f args -> Foreign ty f (fmap (\(x, y) -> (conv x, y)) args)
    While t e -> While (conv t) (conv e)
    DoWhile e t -> DoWhile (conv e) (conv t)
    For x nm y l ar e -> For x nm y (lval l) (conv ar) (conv e)
    TryCatch e1 e2 n f -> TryCatch (conv e1) (conv e2) (conv n) (conv f)
    Throw e -> Throw (conv e)
    Except e1 e2 -> Except (conv e1) (conv e2)
    InferPrint e t f l -> InferPrint (conv e) t f l
    PrintStr e -> PrintStr (conv e)
    PrintNum e -> PrintNum (conv e)
    PrintExc e -> PrintExc (conv e)
    Infix op a b -> Infix op (conv a) (conv b)
    InferInfix op a b ts f l -> InferInfix op (conv a) (conv b) ts f l
    Append a b -> Append (conv a) (conv b)
    Unary op a -> Unary op (conv a)
    InferUnary op a ts f l -> InferUnary op (conv a) ts f l
    Coerce t1 t2 v -> Coerce t1 t2 (conv v)
    InferCoerce t1 t2 v f l -> InferCoerce t1 t2 (conv v) f l
    Case t alts -> Case (conv t) (map alt alts)
    ArrayInit xs -> ArrayInit (map conv xs)
    If a t e -> If (conv a) (conv t) (conv e)
    Index l es -> Index (conv l) (conv es)
    Field v n a t -> Field (conv v) n a t
    Noop -> Noop
    VMPtr -> VMPtr
    Break f l -> Break f l
    VoidReturn -> VoidReturn
    Metavar f l i -> Metavar f l i
    Annotation a e -> Annotation a (conv e)
  where
    -- Convert a sub-expression in the current (unextended) context.
    conv = pToV cs

    -- Context for the scope of a newly bound variable.
    extendCtx n t = cs ++ [(n, (t, [Public]))]

    -- Convert the expressions embedded in an lvalue. Shared by Assign,
    -- AssignOp and For. For's former private copy lacked the AGlob/AField
    -- cases and crashed with an incomplete pattern on them.
    lval (AName i) = AName i
    lval (AGlob i) = AGlob i
    lval (AIndex l r) = AIndex (lval l) (conv r)
    lval (AField l n a t) = AField (lval l) n a t

    -- Convert the bodies (and, for constructor alternatives, the argument
    -- expressions) of a case alternative.
    alt (Default ex) = Default (conv ex)
    alt (ConstAlt pt c ex) = ConstAlt pt c (conv ex)
    alt (Alt n t exs ex) = Alt n t (map conv exs) (conv ex)

-- | Check that the inferred type @t1@ is no less general than the given
-- type @t2@. The two must agree up to renaming of type variables (alpha
-- conversion), and @t1@ must not have a concrete type where @t2@ has a
-- type variable.
--
-- Only generality is checked here. Other structural differences, such as
-- different primitive types or argument lists of different lengths, are
-- not reported by this function. On failure, 'fail' is called with a
-- "file:line:" prefixed message.
checkEq :: Monad m => String -> Int -> Type -> Type -> m ()
checkEq file line t1 t2 = cg t1 t2 [] >> return ()
  where
    lessGeneral = file ++ ":" ++ show line ++ ":" ++
                  "Inferred type less general than given type"
                  ++ " - Inferred " ++ show t1 ++ ", given "
                  ++ show t2

    -- Walk both types in parallel, threading a renaming from inferred
    -- type-variable names to given type-variable names.
    cg (TyVar x) (TyVar y) tvm =
        case lookup x tvm of
          Just z | y == z    -> return tvm
                 | otherwise -> fail lessGeneral
          Nothing -> return ((x, y) : tvm)
    cg _ (TyVar _) _ = fail lessGeneral
    cg (Array x) (Array y) tvm = cg x y tvm
    cg (Fn _ ts t) (Fn _ ts' t') tvm = cg t t' tvm >>= cgl ts ts'
    cg (User _ ts) (User _ ts') tvm = cgl ts ts' tvm
    cg _ _ tvm = return tvm

    -- Pairwise walk over argument lists. Surplus elements on either side
    -- are ignored, consistent with the catch-all above; previously a
    -- length mismatch hit an incomplete pattern and crashed.
    cgl (x:xs) (y:ys) tvm = cg x y tvm >>= cgl xs ys
    cgl _ _ tvm = return tvm


-- | Decide whether a statement leaves the function via 'Return' (or
-- 'Throw', which leaves it just the same):
--
--   * 'Seq' counts if either half does.
--   * 'If' needs both branches.
--   * 'Case' needs at least one alternative, and every alternative's body
--     must count.
--   * 'TryCatch' counts if both its first and second sub-expressions do,
--     or if its fourth one does.
--   * 'Lambda', 'Bind', 'Declare', 'While', 'DoWhile', 'For' and
--     'Annotation' look only at their body.
--   * Anything else does not count.
--
-- NOTE(review): 'While' counts whenever its body does, even though the
-- loop may run zero times. Confirm this leniency is intended.
containsReturn :: Expr Name -> Bool
containsReturn expr = case expr of
    Return _             -> True
    Throw _              -> True
    Lambda _ _ body      -> containsReturn body
    Bind _ _ _ body      -> containsReturn body
    Declare _ _ _ _ body -> containsReturn body
    Seq front back       -> containsReturn front || containsReturn back
    While _ body         -> containsReturn body
    DoWhile body _       -> containsReturn body
    For _ _ _ _ _ body   -> containsReturn body
    Case _ alts          -> not (null alts) && all (containsReturn . altBody) alts
    If _ yes no          -> containsReturn yes && containsReturn no
    TryCatch tryE catchE _ lastE ->
        (containsReturn tryE && containsReturn catchE) || containsReturn lastE
    Annotation _ body    -> containsReturn body
    _                    -> False
  where
    -- The body expression of a case alternative.
    altBody (Default r)      = r
    altBody (ConstAlt _ _ r) = r
    altBody (Alt _ _ _ r)    = r

-- | Does a function type need a runtime check on its result?
--
-- Returns True when some type variable of the result type (per
-- 'getTyVars') occurs in none of the argument types. Such a variable is
-- not fixed by the arguments, so the result cannot be checked at compile
-- time; functions like unmarshal or subvert are the motivating cases.
-- Non-function types never need the check.
needsCheck :: Type -> Bool
needsCheck (Fn _ args r) =
    let argVars = concatMap getTyVars args
    in any (`notElem` argVars) (getTyVars r)
needsCheck _ = False
