(* Copyright (C) 2005, HELM Team.
 * 
 * This file is part of HELM, an Hypertextual, Electronic
 * Library of Mathematics, developed at the Computer Science
 * Department, University of Bologna, Italy.
 * 
 * HELM is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * HELM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HELM; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston,
 * MA  02111-1307, USA.
 * 
 * For details, see the HELM World-Wide-Web page,
 * http://cs.unibo.it/helm/.
 *)

(* $Id: discrimination_tree.ml 8991 2008-09-19 12:47:23Z tassi $ *)

(* One symbol of the linearized ("path string") form of a term.  Every
   element carries an arity: the number of complete argument encodings that
   immediately follow it in a path (read back by [arity_of]).  The type is
   parametric in the payload used to name constants.
   NOTE: constructor declaration order is significant.  [CicIndexable.compare]
   falls back to [Pervasives.compare], which orders distinct constructors by
   their runtime representation, and that representation is assigned from
   the declaration order.  Reordering the constructors would therefore change
   the key order of the discrimination tree's maps. *)
type 'a path_string_elem = 
  | Constant of 'a * int (* name, arity: a named constant applied to [arity] arguments *)
  | Bound of int * int (* rel, arity: a bound variable (Cic.Rel in CicIndexable) applied to [arity] arguments *)
  | Variable (* arity is 0; a subterm the index does not inspect; matches any subterm at retrieval *)
  | Proposition (* arity is 0; the sort Prop in CicIndexable *) 
  | Datatype (* arity is 0; any other sort in CicIndexable *) 
  | Dead (* arity is 0; an opaque term (match/fix/cofix in CicIndexable) *) 
;;  

(* A path string: the preorder sequence of a term's elements, where each
   element is immediately followed by the encodings of its [arity] arguments. *)
type 'a path = ('a path_string_elem) list;;

(* What a discrimination tree needs to know about the terms it indexes:
   how to linearize a term into a path, and how to order path elements. *)
module type Indexable = sig
  (* The terms being indexed / queried. *)
  type input
  (* The payload naming a [Constant] path element. *)
  type constant_name
  (* Total order on path elements; [Make] uses it as the key order of the
     maps inside the trie. *)
  val compare: 
    constant_name path_string_elem -> 
    constant_name path_string_elem -> int
  (* Renders a whole path as a string. *)
  val string_of_path : constant_name path -> string
  (* Linearizes a term into its path string. *)
  val path_string_of : input -> constant_name path
end

module CicIndexable : Indexable 
with type input = Cic.term and type constant_name = UriManager.uri 
= struct

  type input = Cic.term
  type constant_name = UriManager.uri

  (* Renders one path element: "(name,arity)" for constants, "(index,arity)"
     for bound variables, and a fixed keyword for the arity-0 leaves. *)
  let ppelem = function
    | Constant (uri,arity) ->
        "("^UriManager.name_of_uri uri ^ "," ^ string_of_int arity^")"
    | Bound (i,arity) ->
        "("^string_of_int i ^ "," ^ string_of_int arity^")"
    | Variable -> "?"
    | Proposition -> "Prop"
    | Datatype -> "Type"
    | Dead -> "Dead"
  ;;

  (* Linearizes a CIC term into its path string, in preorder.  The head of an
     application carries the number of arguments as its arity and is followed
     by the paths of those arguments.

     Metavariables, implicits, binders and let-ins are not looked into and
     become [Variable].  Match/fix/cofix become [Dead].

     Both [Variable] and [Dead] are arity-0 leaves.  So when one of them is
     the head of an application, the whole application collapses to that
     single leaf.  Emitting the arguments after an arity-0 element would
     produce a malformed path, which [skip]/[skip_root] in [Make] would step
     over incorrectly, causing missed retrievals.

     Casts and nested applications in head position are normalized away for
     the same reason: the head must receive the full argument count. *)
  let path_string_of =
    let rec aux arity = function
      | Cic.Appl ((Cic.Meta _|Cic.Implicit _)::_) -> [Variable]
      | Cic.Appl (Cic.Lambda _ :: _) ->
          [Variable] (* maybe we should beta-reduce *)
      | Cic.Appl [] -> assert false
      | Cic.Appl (Cic.Cast (hd, _) :: tl) ->
          (* Casts are transparent (see the [Cic.Cast] case below).  Strip
             one from the head so that the head cases above still apply, and
             a casted metavariable head does not reach the [arity = 0]
             assertion with a positive arity. *)
          aux arity (Cic.Appl (hd :: tl))
      | Cic.Appl (Cic.Appl (hd :: args) :: tl) ->
          (* Nested application in head position: flatten it so that the
             real head is given the total number of arguments. *)
          aux arity (Cic.Appl (hd :: args @ tl))
      | Cic.Appl (hd::tl) ->
          (match aux (List.length tl) hd with
           | [ (Variable | Dead) as opaque ] ->
               (* arity-0 leaf head: the whole application is opaque *)
               [ opaque ]
           | head -> head @ List.flatten (List.map (aux 0) tl))
      | Cic.Cast (t,_) -> aux arity t
      | Cic.Lambda _ | Cic.Prod _ -> [Variable]
          (* I think we should CicSubstitution.subst Implicit into the body *)
      | Cic.LetIn _ -> [Variable] (* zeta-reduce? *)
      | Cic.Meta _ | Cic.Implicit _ -> assert (arity = 0); [Variable]
      | Cic.Rel i -> [Bound (i, arity)]
      | Cic.Sort (Cic.Prop) -> assert (arity=0); [Proposition]
      | Cic.Sort _ -> assert (arity=0); [Datatype]
      | Cic.Const _ | Cic.Var _
      | Cic.MutInd _ | Cic.MutConstruct _ as t ->
          [Constant (CicUtil.uri_of_term t, arity)]
      | Cic.MutCase _ | Cic.Fix _ | Cic.CoFix _ -> [Dead]
    in
      aux 0
  ;;

  (* Orders path elements.  Two constants are compared by URI (through
     UriManager.compare) and then by arity.  Every other pair falls back to
     structural Pervasives.compare. *)
  let compare e1 e2 =
    match e1,e2 with
    | Constant (u1,a1),Constant (u2,a2) ->
         let x = UriManager.compare u1 u2 in
         if x = 0 then Pervasives.compare a1 a2 else x
    | e1,e2 -> Pervasives.compare e1 e2
  ;;

  (* Renders a path as its elements joined by ".". *)
  let string_of_path l = String.concat "." (List.map ppelem l) ;;
end 

(* Number of complete argument encodings that follow [elem] in a path.  This
   is the recorded arity for [Constant] and [Bound], and 0 for every leaf
   element.  Every constructor is listed so that adding a new one forces this
   function to be revisited. *)
let arity_of elem =
  match elem with
  | Constant (_, n) | Bound (_, n) -> n
  | Variable | Proposition | Datatype | Dead -> 0
;;

(* A discrimination tree: an index from terms (through their path strings) to
   sets of data, supporting retrieval of candidate generalizations and
   candidate unifiers of a query term. *)
module type DiscriminationTree =
    sig

      type input 
      type data
      type dataset
      type constant_name
      type t

      (* Calls the callback on every indexed path together with the dataset
         stored at it. *)
      val iter : t -> (constant_name path -> dataset -> unit) -> unit

      (* The index with no entries. *)
      val empty : t
      (* [index t term d] adds [d] to the dataset stored at [term]'s path. *)
      val index : t -> input -> data -> t
      (* [remove_index t term d] removes [d] from the dataset at [term]'s path.
         In [Make], the path is dropped once its dataset becomes empty, and
         the tree is returned unchanged if the path is absent. *)
      val remove_index : t -> input -> data -> t
      (* [in_index t term p] is true iff some datum stored at exactly
         [term]'s path satisfies [p]. *)
      val in_index : t -> input -> (data -> bool) -> bool
      (* Data whose indexed term may generalize the query.  A [Variable] in
         the index matches any query subterm.  A [Variable] in the query
         matches only an index [Variable].  The result is a candidate set:
         [Variable] carries no identity, so repeated variables are not
         checked. *)
      val retrieve_generalizations : t -> input -> dataset
      (* Data whose indexed term may unify with the query.  A [Variable] on
         either side matches any subterm on the other.  This is also a
         candidate set, not a guarantee of unifiability. *)
      val retrieve_unifiables : t -> input -> dataset
    end

module Make (I:Indexable) (A:Set.S) : DiscriminationTree 
with type constant_name = I.constant_name and type input = I.input
and type data = A.elt and type dataset = A.t =
struct

  type constant_name = I.constant_name
  type input = I.input
  type data = A.elt
  type dataset = A.t

  (* Path elements, ordered by the client-supplied [I.compare]. *)
  module ElemOrd = struct
    type t = I.constant_name path_string_elem
    let compare = I.compare
  end

  (* Children of a trie node, keyed by the next path element. *)
  module ElemMap = Map.Make (ElemOrd)

  (* A trie over path strings whose nodes may carry a dataset. *)
  module Tree = Trie.Make (ElemMap)

  type t = A.t Tree.t

  let empty = Tree.empty

  let iter dt f = Tree.iter f dt

  (* Adds [info] to the dataset stored at [term]'s path, creating the
     dataset if the path was not indexed yet. *)
  let index tree term info =
    let key = I.path_string_of term in
    let current = try Tree.find key tree with Not_found -> A.empty in
    Tree.add key (A.add info current) tree

  (* Removes [info] from the dataset at [term]'s path.  The path is dropped
     when nothing is left, and the tree is returned as is when the path is
     not indexed. *)
  let remove_index tree term info =
    let key = I.path_string_of term in
    try
      let remaining = A.remove info (Tree.find key tree) in
      if not (A.is_empty remaining) then Tree.add key remaining tree
      else Tree.remove key tree
    with Not_found -> tree

  (* True iff some datum stored at exactly [term]'s path satisfies [test]. *)
  let in_index tree term test =
    let key = I.path_string_of term in
    try A.exists test (Tree.find key tree) with Not_found -> false

  (* [skip arity path] drops from [path] the encodings of [arity] complete
     subterms and returns what follows them.  Each consumed element adds its
     own arity to the number still pending.  For the path of h(f(x,g(y,z)),t),
     i.e. (h,2).(f,2).(x,0).(g,2).(y,0).(z,0).(t,0), calling [skip 2] on the
     part after f (the arguments of f) yields the path starting at t.  This
     corresponds to "after_t" in the automated-reasoning literature. *)
  let rec skip arity path =
    match arity, path with
    | 0, _ -> path
    | _, [] -> assert false
    | _, elem :: rest -> skip (arity - 1 + arity_of elem) rest

  (* The trie counterpart of [skip]: returns every subtree reached after
     consuming exactly one complete subterm below the given root. *)
  let skip_root (Tree.Node (_, children)) =
    (* [collect pending node acc] prepends to [acc] the subtrees reached
       after consuming [pending] more complete subterms starting at [node]. *)
    let rec collect pending (Tree.Node (_, sub) as node) acc =
      if pending = 0 then node :: acc
      else
        ElemMap.fold
          (fun elem child acc -> collect (pending - 1 + arity_of elem) child acc)
          sub acc
    in
    ElemMap.fold (fun elem child acc -> collect (arity_of elem) child acc)
      children []

  (* Walks the trie along [term]'s path, following both the exact element
     and any [Variable] child (which swallows one whole query subterm).
     With [unif] set, a query [Variable] swallows one whole indexed subterm
     too; without it, a query [Variable] only matches an index [Variable]. *)
  let retrieve unif tree term =
    let rec walk path (Tree.Node (stored, children) as node) =
      match path with
      | [] ->
          (match stored with Some s -> s | None -> A.empty)
      | Variable :: rest when unif ->
          List.fold_left
            (fun acc sub -> A.union acc (walk rest sub))
            A.empty (skip_root node)
      | elem :: rest ->
          let after = skip (arity_of elem) rest in
          let through_variable =
            try walk after (ElemMap.find Variable children)
            with Not_found -> A.empty
          in
          let through_elem =
            match elem with
            | Variable ->
                (* only reachable when not unifying: already covered by the
                   index-Variable branch above *)
                A.empty
            | _ ->
                (try walk rest (ElemMap.find elem children)
                 with Not_found -> A.empty)
          in
          A.union through_elem through_variable
    in
    walk (I.path_string_of term) tree

  let retrieve_generalizations tree term = retrieve false tree term
  let retrieve_unifiables tree term = retrieve true tree term
end
;;

