(* Copyright (C) 2005, HELM Team.
 * 
 * This file is part of HELM, an Hypertextual, Electronic
 * Library of Mathematics, developed at the Computer Science
 * Department, University of Bologna, Italy.
 * 
 * HELM is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * HELM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HELM; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston,
 * MA  02111-1307, USA.
 * 
 * For details, see the HELM World-Wide-Web page,
 * http://cs.unibo.it/helm/.
 *)

(* $Id: discrimination_tree.ml 7922 2007-11-25 13:12:25Z tassi $ *)

(* Discrimination-tree index of CIC terms, parameterised over the set
 * module [A] holding the data stored under each path. Terms are
 * flattened into paths of [path_string_elem] in pre-order: an
 * application contributes its head followed by the paths of its
 * arguments. *)
module DiscriminationTreeIndexing =
  functor (A:Set.S) ->
    struct

      (* One symbol of a flattened term path. [Variable] stands for a
       * metavariable or an implicit term; [Dead] for terms the tree cannot
       * represent (see [elem_of_cic]). *)
      type path_string_elem =
        | Constant of UriManager.uri
        | Bound of int | Variable | Proposition | Datatype | Dead;;
      type path_string = path_string_elem list;;

      (* Human-readable rendering of a single path element. *)
      let ppelem = function
        | Constant uri -> UriManager.name_of_uri uri
        | Bound i -> string_of_int i
        | Variable -> "?"
        | Proposition -> "Prop"
        | Datatype -> "Type"
        | Dead -> "DEAD"
      ;;

      (* Human-readable rendering of a path, elements joined by "::". *)
      let pppath l = String.concat "::" (List.map ppelem l) ;;

      (* Map a non-application CIC term to its path element.
       * Applications must already have been split by the caller (they hit
       * [assert false]); let-ins, binders, casts, matches and (co)fixpoints
       * are logged and mapped to [Dead]. *)
      let elem_of_cic = function
        | Cic.Meta _ | Cic.Implicit _ -> Variable
        | Cic.Rel i -> Bound i
        | Cic.Sort (Cic.Prop) -> Proposition
        | Cic.Sort _ -> Datatype
        | Cic.Const _ | Cic.Var _ | Cic.MutInd _ | Cic.MutConstruct _ as t ->
            (try Constant (CicUtil.uri_of_term t)
            with Invalid_argument _ -> assert false)
        | Cic.Appl _ ->
            assert false (* should not happen *)
        | Cic.LetIn _ | Cic.Lambda _ | Cic.Prod _ | Cic.Cast _
        | Cic.MutCase _ | Cic.Fix _ | Cic.CoFix _ ->
            HLog.debug "FIXME: the trie receives an invalid term";
            Dead
            (* assert false universe.ml removes these *)
      ;;

      (* [path_string_of_term arities t] flattens [t] into its pre-order
       * path and returns it with [arities] updated: every application head
       * is recorded with its number of arguments, replacing any previous
       * entry. [Dead] heads are not recorded, and a [Variable] head must be
       * unapplied (assertion). *)
      let path_string_of_term arities =
        let set_arity arities k n =
          (assert (k<>Variable || n=0);
           if k = Dead then arities else (k,n)::(List.remove_assoc k arities))
        in
        let rec aux arities = function
          | Cic.Appl ((hd::tl) as l) ->
              let arities =
                set_arity arities (elem_of_cic hd) (List.length tl) in
              (* Accumulate the sub-paths in reverse and concatenate once:
               * appending to the accumulator at every step would be
               * quadratic in the number of arguments. *)
              let arities, rev_paths =
                List.fold_left
                  (fun (arities, acc) t ->
                     let arities, tpath = aux arities t in
                     arities, tpath :: acc)
                  (arities, []) l
              in
              arities, List.concat (List.rev rev_paths)
          | t -> arities, [elem_of_cic t]
        in
          aux arities
      ;;

      (* Total order on path elements: URIs are compared with
       * [UriManager.compare], everything else structurally. Plain [compare]
       * is used because the [Pervasives] module no longer exists in
       * OCaml 5. *)
      let compare_elem e1 e2 =
        match e1,e2 with
        | Constant u1,Constant u2 -> UriManager.compare u1 u2
        | e1,e2 -> compare e1 e2
      ;;

      module OrderedPathStringElement = struct
        type t = path_string_elem
        let compare = compare_elem
      end

      module PSMap = Map.Make(OrderedPathStringElement);;

      type key = PSMap.key

      module DiscriminationTree = Trie.Make(PSMap);;

      (* The tree mapping paths to the data indexed under them, paired with
       * the arity table built by [path_string_of_term]; [retrieve_unifiables]
       * needs the arities (through [jump_list]) to know how many arguments
       * each "function" symbol takes. *)
      type t = A.t DiscriminationTree.t * (path_string_elem*int) list
      let empty = DiscriminationTree.empty, [] ;;

      (* Add [info] to the set stored under the path of [term], recording
       * the arities found in [term]. *)
      let index (tree,arity) term info =
        let arity,ps = path_string_of_term arity term in
        let ps_set =
          try DiscriminationTree.find ps tree
          with Not_found -> A.empty in
        let tree = DiscriminationTree.add ps (A.add info ps_set) tree in
        tree,arity
      ;;

      (* Remove [info] from the set stored under the path of [term]; the
       * path is dropped when its set becomes empty, and an absent path
       * leaves the tree unchanged. The arity table is updated with the
       * arities of [term] in every case. *)
      let remove_index (tree,arity) term info =
        let arity,ps = path_string_of_term arity term in
        try
          let ps_set = A.remove info (DiscriminationTree.find ps tree) in
          if A.is_empty ps_set then
            DiscriminationTree.remove ps tree,arity
          else
            DiscriminationTree.add ps ps_set tree,arity
        with Not_found ->
          tree,arity
      ;;

      (* Whether some datum stored under exactly the path of [term]
       * satisfies [test]. *)
      let in_index (tree,arity) term test =
        let _, ps = path_string_of_term arity term in
        try
          let ps_set = DiscriminationTree.find ps tree in
          A.exists test ps_set
        with Not_found ->
          false
      ;;

      (* Head of an application; any other term is its own head. *)
      let head_of_term = function
        | Cic.Appl (hd::_) -> hd
        | term -> term
      ;;

      (* Strip every leading product, returning the conclusion. *)
      let rec skip_prods = function
        | Cic.Prod (_,_,t) -> skip_prods t
        | term -> term
      ;;

      (* [subterm_at_pos pos term] follows [pos], a list of indices into
       * application lists (index 0 is the head), down from [term].
       * Raises [Not_found] if [pos] does not denote a subterm. *)
      let rec subterm_at_pos pos term =
        match pos with
          | [] -> term
          | index::pos ->
              match term with
                | Cic.Appl l ->
                    (try subterm_at_pos pos (List.nth l index)
                     with Failure _ -> raise Not_found)
                | _ -> raise Not_found
      ;;

      (* Position right after the whole subterm at [pos]: its next sibling
       * if any, otherwise the next sibling of the nearest ancestor that has
       * one. Raises [Not_found] when [pos] is [] or no such position exists
       * (the subterm at [pos] ends the term). *)
      let rec after_t pos term =
        let pos' =
          match pos with
            | [] -> raise Not_found
            | pos ->
                List.fold_right
                  (fun i r -> if r = [] then [i+1] else i::r) pos []
        in
          try
            ignore(subterm_at_pos pos' term ); pos'
          with Not_found ->
            (* drop the last index and move past the parent instead *)
            let pos, _ =
              List.fold_right
                (fun i (r, b) -> if b then (i::r, true) else (r, true))
                pos ([], false)
            in
              after_t pos term
      ;;

      (* Next position in the pre-order walk: the first argument when the
       * subterm at [pos] is an application, otherwise [after_t pos term].
       * At the root it always answers [1], even for an atomic [term].
       * Raises [Not_found] when the walk is over. *)
      let next_t pos term =
        let t = subterm_at_pos pos term in
          try
            let _ = subterm_at_pos [1] t in
              pos @ [1]
          with Not_found ->
            match pos with
              | [] -> [1]
              | pos -> after_t pos term
      ;;

      (* [retrieve_generalizations (tree, _) term] returns the union of the
       * data whose indexed path may generalize [term] (leading products
       * stripped): an indexed [Variable] absorbs one whole subterm of
       * [term], while a metavariable of [term] only matches an indexed
       * [Variable]. The result is a candidate set. *)
      let retrieve_generalizations (tree,_) term =
        let term = skip_prods term in
        let rec retrieve tree term pos =
          match tree with
            | DiscriminationTree.Node (Some s, _) when pos = [] -> s
            | DiscriminationTree.Node (_, map) ->
                (* candidates whose next symbol is the head at [pos] *)
                let res =
                  let hd_term =
                    elem_of_cic (head_of_term (subterm_at_pos pos term))
                  in
                  if hd_term = Variable then A.empty else
                  try
                    let n = PSMap.find hd_term map in
                      match n with
                        | DiscriminationTree.Node (Some s, _) -> s
                        | DiscriminationTree.Node (None, _) ->
                            (* [next_t] fails once [term] is exhausted; [n]
                             * then only leads to longer indexed paths, so
                             * nothing below it matches. Falling back to
                             * position [] would restart the match at the
                             * root of [term] from inside the tree. *)
                            (match
                               (try Some (next_t pos term)
                                with Not_found -> None)
                             with
                               | Some newpos -> retrieve n term newpos
                               | None -> A.empty)
                  with Not_found ->
                    A.empty
                in
                  (* candidates with an indexed [Variable] here, which
                   * absorbs the whole subterm at [pos] *)
                  try
                    let n = PSMap.find Variable map in
                    let newpos = try after_t pos term with Not_found -> [-1] in
                      if newpos = [-1] then
                        match n with
                          | DiscriminationTree.Node (Some s, _) -> A.union s res
                          | _ -> res
                      else
                        A.union res (retrieve n term newpos)
                  with Not_found ->
                    res
        in
          retrieve tree term []
      ;;

      (* [jump_list arities node] returns the trie nodes reached from
       * [node] by skipping exactly one indexed subterm: each symbol met
       * adds its arity (0 when absent from [arities]) to the number of
       * subterms still to skip. *)
      let jump_list arities = function
        | DiscriminationTree.Node (_, map) ->
            let arity_of k = try List.assoc k arities with Not_found -> 0 in
            let rec get n tree =
              match tree with
                | DiscriminationTree.Node (_, m) ->
                    if n = 0 then
                      [tree]
                    else
                      PSMap.fold
                        (fun k child res -> get (n-1 + arity_of k) child @ res)
                        m []
            in
              PSMap.fold
                (fun k child res -> get (arity_of k) child @ res)
                map []
      ;;

      (* [retrieve_unifiables (tree, arities) term] returns the union of the
       * data whose indexed path may unify with [term] (leading products
       * stripped): a metavariable of [term] skips one whole indexed subterm
       * (through [jump_list]) and an indexed [Variable] absorbs one whole
       * subterm of [term]. The result is a candidate set; presumably the
       * callers still check unification on it (TODO confirm). *)
      let retrieve_unifiables (tree,arities) term =
        let term = skip_prods term in
        let rec retrieve tree term pos =
          match tree with
            | DiscriminationTree.Node (Some s, _) when pos = [] -> s
            | DiscriminationTree.Node (_, map) ->
                let subterm =
                  try Some (subterm_at_pos pos term) with Not_found -> None
                in
                match subterm with
                | None -> A.empty
                | Some (Cic.Meta _) ->
                    let jl = jump_list arities tree in
                    (match (try Some (next_t pos term) with Not_found -> None) with
                     | Some newpos ->
                         List.fold_left
                           (fun r t -> A.union r (retrieve t term newpos))
                           A.empty jl
                     | None ->
                         (* [term] is exhausted: only the jumped-to nodes
                          * that end an indexed path contribute; descending
                          * further would restart at the root of [term]. *)
                         List.fold_left
                           (fun r t ->
                              match t with
                                | DiscriminationTree.Node (Some s, _) ->
                                    A.union r s
                                | DiscriminationTree.Node (None, _) -> r)
                           A.empty jl)
                | Some subterm ->
                    (* candidates whose next symbol is the head of [subterm];
                     * if [next_t] fails ([term] exhausted) the handler
                     * below yields no candidates *)
                    let res =
                      let hd_term = elem_of_cic (head_of_term subterm) in
                        if hd_term = Variable then
                          A.empty else
                        try
                          let n = PSMap.find hd_term map in
                            match n with
                              | DiscriminationTree.Node (Some s, _) -> s
                              | DiscriminationTree.Node (None, _) ->
                                  retrieve n term (next_t pos term)
                        with Not_found ->
                          A.empty
                    in
                      (* candidates with an indexed [Variable] here *)
                      try
                        let n = PSMap.find Variable map in
                        let newpos =
                          try after_t pos term
                          with Not_found -> [-1]
                        in
                          if newpos = [-1] then
                            match n with
                              | DiscriminationTree.Node (Some s, _) ->
                                  A.union s res
                              | _ -> res
                          else
                            A.union res (retrieve n term newpos)
                      with Not_found ->
                        res
        in
          retrieve tree term []
  end
;;

