-- HatCover: Colin Runciman, University of York, October 2004
-- Print Haskell sources highlighting the maximal expressions with
-- instances recorded in a given Hat trace.

import HighlightStyle (highlightOn, highlightOff, Highlight(..), Colour(..))
import List (isPrefixOf, isSuffixOf)
import IO (stderr, hPutStrLn)
import System (getArgs, getProgName, exitWith, ExitCode(..))
import LowLevel (openHatFile, NodeType(..), nodeSequence, FileNode(..), getSrcRef)
import FFIExtensions (withCString)
import SrcRef (SrcRef(..), readSrcRef)

-- | A source position as a (line, column) pair.
type LC = (Int,Int)
-- | A closed interval given by its first and last elements.
type Interval a = (a,a)
-- | A source reference: file name plus the span of text it covers.
type SrcRef' = (String, Interval LC)

-- | @outer `includes` inner@ holds when @inner@ lies within @outer@.
-- Intervals are closed, so shared end points still count as inclusion.
includes :: Ord a => Interval a -> Interval a -> Bool
includes (outerLo, outerHi) (innerLo, innerHi) =
  outerLo <= innerLo && innerHi <= outerHi

-- | @a `precedes` b@ holds when @a@ ends strictly before @b@ starts, so the
-- two intervals do not touch.
precedes :: Ord a => Interval a -> Interval a -> Bool
precedes (_, firstHi) (secondLo, _) = firstHi < secondLo

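-- The record of which expressions have been covered is represented, for
-- each source file, as a binary search tree of disjoint source intervals.
-- An interval inside one already recorded is dropped, and an interval
-- enclosing recorded ones replaces them, so only maximal intervals remain.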

-- | Binary tree. Here its elements are intervals, kept in source order
-- (see 'insert').
data Tree a
  = Leaf
  | Fork (Tree a) a (Tree a)

-- | Covered intervals: one tree per source file name.
type Cover = [(String, Tree (Interval LC))]

-- | True for every defined tree. Evaluating it forces the argument to weak
-- head normal form, which is why it is used in guards as a strictness device.
normalTree :: Tree a -> Bool
normalTree t =
  case t of
    Leaf   -> True
    Fork{} -> True

-- | Strict version of 'Fork'. Both subtrees are evaluated to weak head
-- normal form before the node is built.
fork :: Tree a -> a -> Tree a -> Tree a
fork back x forw = back `seq` forw `seq` Fork back x forw

-- | True for every defined list. Evaluating it forces the list to weak head
-- normal form (its first constructor only).
normalList :: [a] -> Bool
normalList xs = xs `seq` True

-- | Strict @(:)@. The tail is evaluated to weak head normal form before the
-- new cell is built.
cons :: a -> [a] -> [a]
cons x xs = xs `seq` (x : xs)

-- | Record one source interval in the cover.
--
-- The interval is inserted into the tree of its own file. A file not yet
-- present is appended at the end of the cover with a one-interval tree.
-- The rebuilt tree, and (via 'cons') the rest of the list, are forced to
-- weak head normal form before being consed on.
add :: SrcRef' -> Cover -> Cover
add (f,i) cover =
  case cover of
    [] -> [(f, single i)]
    entry@(g,t) : rest
      | g == f    -> let t' = insert i t in t' `seq` ((g, t') : rest)
      | otherwise -> cons entry (add (f,i) rest)

-- | Add the source intervals to the cover one at a time, left to right.
-- The accumulated cover is forced to weak head normal form at every step.
addAll :: [SrcRef'] -> Cover -> Cover
addAll srs cvr =
  case srs of
    []        -> cvr
    sr : rest -> cvr `seq` addAll rest (add sr cvr)

-- | A tree holding exactly one element.
single :: a -> Tree a
single = \x -> Fork Leaf x Leaf

-- | Insert an interval into an ordered tree of disjoint intervals.
--
-- * If the interval lies inside one already present, the tree is unchanged.
-- * If it encloses existing intervals, those are removed (via 'outside')
--   and the new interval takes their place.
--
-- Intervals that overlap without nesting are not expected: the final case
-- assumes @i@ includes @j@.
insert :: Ord a => Interval a -> Tree (Interval a) -> Tree (Interval a)
insert i Leaf = single i
insert i t@(Fork back j forw)
  | i `precedes` j = fork (insert i back) j forw
  | j `precedes` i = fork back j (insert i forw)
  | j `includes` i = t
  | otherwise      = fork (back `outside` i) i (forw `outside` i)
  
-- | Keep only the intervals of the tree that lie strictly before or after
-- @cd@. Any interval that neither precedes nor follows @cd@ is assumed to
-- lie inside it and is dropped; its subtrees are then joined with 'graft'.
outside :: Ord a => Tree (Interval a) -> Interval a -> Tree (Interval a)
outside Leaf _ = Leaf
outside (Fork back ab forw) cd
  | ab `precedes` cd = fork back ab (forw `outside` cd)
  | cd `precedes` ab = fork (back `outside` cd) ab forw
  | otherwise        = graft (back `outside` cd) (forw `outside` cd)

-- | Join two trees by hanging the second off the rightmost position of the
-- first, so in-order sequence is the first tree followed by the second.
-- The result is only ordered if every interval of the first tree precedes
-- every interval of the second, as is the case where 'outside' calls it.
graft :: Tree (Interval a) -> Tree (Interval a) -> Tree (Interval a)
graft left right =
  case left of
    Leaf             -> right
    Fork back i forw -> fork back i (graft forw right)

-- | List the elements of a tree in order: left subtree, node, right subtree.
--
-- An accumulating helper conses each element exactly once. The naive
-- definition @flatten back ++ [i] ++ flatten forw@ re-traverses the left
-- results at every level, which is quadratic for left-leaning trees.
-- The unbalanced trees built by 'insert' can be left-leaning.
flatten :: Tree a -> [a]
flatten t = go t []
  where
  -- go tree acc == (in-order elements of tree) ++ acc
  go Leaf               acc = acc
  go (Fork back i forw) acc = go back (i : go forw acc)

-- | Print every source file recorded in the cover, in list order.
-- The pair holds the highlight-on and highlight-off strings.
printCover :: (String, String) -> Cover -> IO ()
printCover hiOnOff cover =
  sequence_ [printModuleCover hiOnOff m | m <- cover]

-- | Read one source file, expand its tabs, and print it with the intervals
-- of its cover tree highlighted. Printing starts at line 1, column 1.
printModuleCover :: (String, String) -> (String, (Tree (Interval LC))) -> IO ()
printModuleCover hiOnOff (f, c) =
  readFile f >>= printLo hiOnOff (1,1) (flatten c) . map expand . lines

-- | Print source text that lies outside the covered intervals, unhighlighted.
--
-- @(lineNo,colNo)@ is the source position of the first character of
-- @srcLines@. The intervals are expected in source order and disjoint, as
-- produced by 'flatten' on a cover tree. Text up to the start of the next
-- interval is printed plainly, and the remainder is handed to 'printHi'.
printLo :: (String, String) -> LC -> [Interval LC] -> [String] -> IO ()
printLo _ _      [] srcLines =
  mapM_ putStrLn srcLines
printLo hiOnOff (lineNo,colNo) (((lstart,cstart),(lstop,cstop)):ivals) srcLines =
  do
    mapM_ putStrLn plainLines
    putStr (take chLo current)
    printHi hiOnOff
      (lstart,cstart) (lstop,cstop) ivals
      (drop chLo current : rest)
  where
  -- whole lines printed before the line on which the interval starts
  lnLo = lstart-lineNo
  -- plain characters on the interval's first line; on the current line the
  -- count is relative to colNo, on a later line relative to column 1
  chLo = cstart-(if lnLo==0 then colNo else 1)
  (plainLines, remaining) = splitAt lnLo srcLines
  -- An interval can refer past the end of the file, e.g. if the source was
  -- edited after the trace was made. Carry on with an empty line instead of
  -- failing with `head []`.
  (current, rest) = case remaining of
                      (l:ls) -> (l, ls)
                      []     -> ("", [])
  
-- | Print the text of one covered interval highlighted, then resume plain
-- printing with 'printLo'.
--
-- @(lineNo,colNo)@ is the source position of the first character of
-- @srcLines@. @(lstop,cstop)@ is the interval's last character, inclusive.
-- Each printed line is highlighted separately with 'high'.
printHi :: (String, String) -> LC -> LC -> [Interval LC] -> [String] -> IO ()
printHi hiOnOff (lineNo,colNo) (lstop,cstop) ivals srcLines =
  do
    mapM_ (putStrLn . high hiOnOff) coveredLines
    putStr (high hiOnOff (take chHi current))
    printLo hiOnOff
      (lstop,cstop+1) ivals
      (drop chHi current : rest)
  where
  -- whole lines covered before the line on which the interval ends
  lnHi = lstop-lineNo
  -- covered characters on the last line; +1 because cstop is inclusive
  chHi = 1+cstop-(if lnHi==0 then colNo else 1)
  (coveredLines, remaining) = splitAt lnHi srcLines
  -- An interval can refer past the end of the file, e.g. if the source was
  -- edited after the trace was made. Carry on with an empty line instead of
  -- failing with `head []`.
  (current, rest) = case remaining of
                      (l:ls) -> (l, ls)
                      []     -> ("", [])
  
-- | Wrap a line in the highlight-on/off strings. Leading spaces are left
-- outside the highlighted region.
high :: (String, String) -> String -> String
high (hiOn, hiOff) s = lead ++ hiOn ++ body ++ hiOff
  where
  (lead, body) = span (== ' ') s

-- | Entry point. Arguments, as parsed below:
--
-- > [-hion=STR] [-hioff=STR] trace[.hat] [module[.hs] ...]
--
-- Options must come before the trace file. When module names are given,
-- only expressions from those source files are reported.
main :: IO ()
main =
  do
    args    <- getArgs
    prog    <- getProgName
    let (options, nonOptions) = span ("-" `isPrefixOf`) args
    -- values given for an option of the form "<key><value>"
    let optionValues key = [ drop (length key) opt
                           | opt <- options, key `isPrefixOf` opt ]
    -- Each highlight string falls back to its own default. Previously the
    -- -hioff= default was chosen by testing the -hion= values, so "-hion=X"
    -- alone crashed (head []) and "-hioff=Y" alone was ignored.
    let hiOn  = firstOr (highlightOn [Bold]) (optionValues "-hion=")
    let hiOff = firstOr highlightOff (optionValues "-hioff=")
    (hatfile, moduleNames) <-
      case nonOptions of
        (t:ms) -> return (rectify ".hat" t, map (rectify ".hs") ms)
        []     -> do hPutStrLn stderr (prog++": no trace file")
                     exitWith (ExitFailure 1)
    withCString prog (withCString hatfile . openHatFile)
    nodes   <- nodeSequence
    let srs = [ convert sr |
                (fn,nt) <- nodes, isCover nt,
                let srn = getSrcRef fn, srn /= FileNode 0,
                let sr = readSrcRef srn, line sr /= 0,
                null moduleNames || filename sr `elem` moduleNames ]
    printCover (hiOn, hiOff) (addAll srs [])
  where
  -- first element of a list, or the default when the list is empty
  firstOr dflt []    = dflt
  firstOr _    (v:_) = v

-- | Append the extension to a file name unless it already ends with it.
rectify :: String -> FilePath -> FilePath
rectify ext f = if ext `isSuffixOf` f then f else f ++ ext

-- | Trace node types whose source references count as covered expressions.
-- Every other node type is ignored.
isCover :: NodeType -> Bool
isCover nt =
  case nt of
    ExpApp         -> True
    ExpValueApp    -> True
    ExpChar        -> True
    ExpInt         -> True
    ExpInteger     -> True
    ExpRat         -> True
    ExpRational    -> True
    ExpFloat       -> True
    ExpDouble      -> True
    ExpValueUse    -> True
    ExpConstUse    -> True
    ExpFieldUpdate -> True
    ExpProjection  -> True
    _              -> False

-- | Repackage a 'SrcRef' as a file name plus its start and end positions.
convert :: SrcRef -> SrcRef'
convert (SrcRef file startLine startCol endLine endCol) =
  (file, (start, end))
  where
  start = (startLine, startCol)
  end   = (endLine, endCol)

-- | Expand the tabs in a line that starts at column 1.
expand :: String -> String
expand s = expandFrom 1 s

-- | Replace each tab by enough spaces to reach the next tab stop.
-- The first argument is the 1-based column of the first character.
-- Tab stops are every 8 columns (1, 9, 17, ...).
expandFrom :: Int -> String -> String
expandFrom _ "" = ""
expandFrom col (c:cs)
  | c == '\t' = replicate width ' ' ++ expandFrom (col + width) cs
  | otherwise = c : expandFrom (col + 1) cs
  where
  -- distance to the next tab stop (1..8)
  width = 8 - (col - 1) `mod` 8

-- | An infinite supply of spaces.
spaces :: String
spaces = cycle " "


