(**************************************************************************)
(*                                                                        *)
(*     The Alt-ergo theorem prover                                        *)
(*     Copyright (C) 2006-2008                                            *)
(*                                                                        *)
(*     Sylvain Conchon                                                    *)
(*     Evelyne Contejean                                                  *)
(*     CNRS-LRI-Universite Paris Sud                                      *)
(*                                                                        *)
(*   This file is distributed under the terms of the CeCILL-C licence     *)
(*                                                                        *)
(**************************************************************************)

open Format
open Options
open Sig

(* Interface of a congruence-closure environment over literals. *)
module type S = sig
  (* Persistent (functional) environment: every operation returns a new [t]. *)
  type t

  (** Fresh environment. *)
  val empty : unit -> t

  (** [add a env] registers the terms of literal [a] in [env] without
      asserting [a] itself. *)
  val add : Literal.t -> t -> t

  (** [assume a dep env] asserts literal [a], justified by explanation
      [dep]. *)
  val assume : Literal.t -> Explanation.t -> t -> t

  (** [query a env] tells whether [a] is already entailed by [env]. *)
  val query : Literal.t -> t -> bool

  (** [class_of env t] lists the terms currently known equal to [t]. *)
  val class_of : t -> Term.t -> Term.t list

  (** [explain a env] returns the explanation justifying literal [a]. *)
  val explain : Literal.t -> t -> Explanation.t
end

(* Congruence closure built on the union-find [Uf] and the relation
   solver [X.R].  The [use] table maps a semantic value [r] to the terms
   (resp. non-equality atoms) having [r] among the leaves of their
   arguments' representatives; it is consulted after each merge to find
   new congruences and to replay atoms to the relation. *)
module Make (X : Sig.X) = struct

  module Ex = Explanation
  module Uf = Uf.Make(X)
  module SetF = Formula.Set
  module T = Term
  module A = Literal
  module SetT = Term.Set
  module SetA = A.Set
  module SAeq = A.SetEq

  (* Maps keyed by semantic values; [find] is total and yields a pair of
     empty sets for unbound keys. *)
  module MX = struct
    include Map.Make(struct type t = X.r include X end)
    let find k m = try find k m with Not_found -> (SetT.empty, SetA.empty)
  end

  type t = {
    use : (SetT.t * SetA.t) MX.t;  (* leaf -> (using terms, using atoms) *)
    uf : Uf.t;                      (* union-find over terms *)
    relation : X.R.t                (* state of the relation solver *)
  }

  (* Debug printers.  All cc traces go to the shared formatter [fmt]. *)
  module Print = struct

    let sterms fmt = SetT.iter (fprintf fmt "%a " T.print)
    let satoms fmt = SetA.iter (fprintf fmt "%a " A.print)
    let lrepr fmt = List.iter (fprintf fmt "%a " X.print)

    let use env =
      fprintf fmt "@{<C.Bold>[cc]@} use table:\n";
      MX.iter
        (fun t (st, sa) ->
           fprintf fmt "%a is used by {%a} and {%a} \n"
             X.print t sterms st satoms sa)
        env.use

    let congruent t1 t2 =
      fprintf fmt "@{<C.Bold>[cc]@} cong %a=%a ?@." T.print t1 T.print t2

    let add_to_use t = fprintf fmt "@{<C.Bold>[cc]@} add_to_use: %a@." T.print t

    let close_up t1 t2 =
      fprintf fmt "@{<C.Bold>[cc]@} close_up: %a=%a@." T.print t1 T.print t2

    let leaves t lvs =
      fprintf fmt "@{<C.Bold>[cc]@} leaves of %a@.@." T.print t; lrepr fmt lvs
  end

  (* Union of the union-find explanations of the pairwise equalities
     x_i = y_i.  The two lists must have the same length, otherwise
     [List.fold_left2] raises [Invalid_argument]. *)
  let compat_leaves env =
    List.fold_left2
      (fun dep x y -> Ex.union (Uf.explain env.uf x y) dep) Ex.empty

  (* [congruent env u1 u2] returns the explanation making [u1] and [u2]
     congruent: same head symbol, same type, same arity, and argument-wise
     explanations from [compat_leaves].
     @raise Exception.NotCongruent otherwise.  The arity check ensures a
     length mismatch reports NotCongruent (which callers catch) instead of
     escaping as [Invalid_argument] from [List.fold_left2]. *)
  let congruent env u1 u2 =
    if debug_cc then Print.congruent u1 u2;
    let {T.f = f1; xs = xs1; ty = ty1} = T.view u1 in
    let {T.f = f2; xs = xs2; ty = ty2} = T.view u2 in
    if Symbols.equal f1 f2 && Ty.equal ty1 ty2
       && List.length xs1 = List.length xs2
    then compat_leaves env xs1 xs2
    else raise Exception.NotCongruent

  (* Placeholder leaf used when a value has no leaves, so that leaf lists
     produced below are never empty. *)
  let one = X.make (Term.make (Symbols.name "@bottom") [] Ty.Tint)

  (* Leaves of [r], or [[one]] when [X.leaves r] is empty. *)
  let leaves r = match X.leaves r with [] -> [one] | l -> l

  (* Concatenation (in no particular order, duplicates kept) of the leaves
     of the representatives of the terms in [l]; [[one]] if there are none. *)
  let concat_leaves uf l =
    let concat_rec acc t =
      match X.leaves (Uf.find uf t), acc with
      | [], _ -> acc
      | res, [] -> res
      | res, _ -> List.rev_append res acc
    in
    match List.fold_left concat_rec [] l with
    | [] -> [one]
    | res -> res

  (* [close_up t1 t2 dep env] merges [t1] and [t2] with explanation [dep],
     then for every representative [p] updated by the union (to value [v],
     with the values listed in [touched]):
     - propagates [p]'s use sets to the leaves of [v];
     - re-merges the pairs of congruent terms found between the terms
       using [p] and those using all leaves of some touched value;
     - replays the affected atoms to the relation solver. *)
  let rec close_up t1 t2 dep env =
    if debug_cc then Print.close_up t1 t2;
    if Uf.equal env.uf t1 t2 then env
    else
      let uf, res = Uf.union env.uf t1 t2 dep in
      List.fold_left
        (fun env (p, touched, v) ->
           let (st_ftr_p, sa_ftr_p) as ftr_p = MX.find p env.use in
           (* ftr_p extended, for each touched value r, with the terms and
              atoms using every leaf of r. *)
           let st_others, sa_others =
             List.fold_left
               (fun (st, sa) r ->
                  let rst, rsa =
                    match leaves r with
                    | [] -> (SetT.empty, SetA.empty)
                    | a :: l ->
                        List.fold_left
                          (fun (st, sa) lf ->
                             let st_l, sa_l = MX.find lf env.use in
                             SetT.inter st st_l, SetA.inter sa sa_l)
                          (MX.find a env.use) l
                  in
                  (SetT.union st rst, SetA.union sa rsa))
               ftr_p touched
           in
           let use =
             List.fold_left
               (fun use l ->
                  let st_l, sa_l = MX.find l env.use in
                  let st = SetT.union st_l st_ftr_p in
                  let sa = SetA.union sa_l sa_ftr_p in
                  MX.add l (st, sa) use)
               env.use (leaves v)
             (* TODO (translated from the original French note): clean up
                the entry of r in [use]. *)
           in
           let env =
             SetT.fold
               (fun x env ->
                  SetT.fold
                    (fun y env ->
                       try close_up x y (congruent env x y) env
                       with Exception.NotCongruent -> env)
                    st_others env)
               st_ftr_p {env with use = use}
           in
           replay_atom (SetA.union sa_ftr_p sa_others) env)
        {env with uf = uf} res

  (* Feeds the atoms [sa] to the relation solver.  For every equality
     (a, t1, t2) it deduces, [a] is dropped from the atom sets of the
     representatives of [t1] and [t2], and t1 = t2 is merged with
     explanation [Ex.everything]. *)
  and replay_atom sa env =
    let rel, nsa =
      X.R.assume env.relation sa (Uf.find env.uf) (Uf.class_of env.uf) in
    SAeq.fold
      (fun (a, t1, t2) env ->
         let r1 = Uf.find env.uf t1 in
         let r2 = Uf.find env.uf t2 in
         let st_r1, sa_r1 = MX.find r1 env.use in
         let st_r2, sa_r2 = MX.find r2 env.use in
         (* Both entries are read before either is written, so r1 = r2 is
            handled consistently. *)
         let sa_r1', sa_r2' = SetA.remove a sa_r1, SetA.remove a sa_r2 in
         let use = MX.add r1 (st_r1, sa_r1') env.use in
         let use = MX.add r2 (st_r2, sa_r2') use in
         close_up t1 t2 Ex.everything {env with use = use})
      nsa {env with relation = rel}

  (* Prepends (t, t', explanation) to [acc] for every t' in [s] that is
     congruent to [t]. *)
  let congruents e t s acc =
    SetT.fold
      (fun t' acc ->
         try (t, t', congruent e t t') :: acc
         with Exception.NotCongruent -> acc)
      s acc

  (* Registers term [t] (and, recursively, its arguments) in the
     union-find and the [use] table.  [ct] accumulates congruent pairs
     discovered along the way; they are not merged here. *)
  let rec add_term (env, ct) t =
    if debug_cc then Print.add_to_use t;
    if Uf.mem env.uf t then (env, ct)
    else
      let env = {env with uf = Uf.add env.uf t} in
      let {T.xs = xs} = T.view t in
      let env, ct = List.fold_left add_term (env, ct) xs in
      let rt = Uf.find env.uf t in
      let env =
        if MX.mem rt env.use then env
        else {env with use = MX.add rt (SetT.empty, SetA.empty) env.use}
      in
      let lvs = concat_leaves env.uf xs in
      (* Record [t] as a user of every leaf of its arguments, and compute
         the terms already using all of those leaves: the only candidates
         for congruence with [t]. *)
      let env, st_uset =
        match lvs with
        | [] -> (env, SetT.empty)
        | a :: l ->
            let st_a, sa_a = MX.find a env.use in
            let init =
              ({env with use = MX.add a (SetT.add t st_a, sa_a) env.use},
               st_a)
            in
            List.fold_left
              (fun (env, ist_uset) rx ->
                 let st_uset, sa_uset = MX.find rx env.use in
                 ({env with use = MX.add rx (SetT.add t st_uset, sa_uset) env.use},
                  SetT.inter ist_uset st_uset))
              init l
      in
      (env, congruents env t st_uset ct)

  (* Registers all terms of literal [a], merging the congruences found.
     Atoms other than Eq/Neq are also recorded in the atom sets of the
     leaves of their terms' representatives. *)
  let add a env =
    let st = A.terms_of a in
    let env =
      SetT.fold
        (fun t env ->
           let env, ct = add_term (env, []) t in
           List.fold_left (fun e (x, y, dep) -> close_up x y dep e) env ct)
        st env
    in
    match A.view a with
    | A.Eq _ | A.Neq _ -> env
    | _ ->
        let lvs = concat_leaves env.uf (Term.Set.elements st) in
        List.fold_left
          (fun env rx ->
             let st_uset, sa_uset = MX.find rx env.use in
             {env with use = MX.add rx (st_uset, SetA.add a sa_uset) env.use})
          env lvs

  (* Asserts [a] with explanation [dep]: equalities are merged, while
     disequalities are recorded in the union-find and, like every other
     literal, replayed to the relation solver. *)
  let assume a dep env =
    if debug_cc then Print.use env;
    let env = add a env in
    match A.view a with
    | A.Eq (t1, t2) -> close_up t1 t2 dep env
    | A.Neq (t1, t2) ->
        let env = {env with uf = Uf.distinct env.uf t1 t2 dep} in
        replay_atom (SetA.singleton a) env
    | _ -> replay_atom (SetA.singleton a) env

  let class_of env t = Uf.class_of env.uf t

  (* Explanation of [a]: from the union-find for Eq/Neq, [Ex.everything]
     otherwise.  A NotCongruent escaping here is an internal error. *)
  let explain a env =
    try
      match A.view a with
      | A.Eq (x, y) -> Uf.explain env.uf x y
      | A.Neq (x, y) -> Uf.neq_explain env.uf x y
      | _ -> Ex.everything
    with Exception.NotCongruent -> assert false

  (* Eq/Neq are answered by the union-find; other literals are delegated
     to [X.R.query] on the negation of [a]. *)
  let query a env =
    if debug_cc then Print.use env;
    match A.view a with
    | A.Eq (t1, t2) -> Uf.equal env.uf t1 t2
    | A.Neq (t1, t2) -> Uf.are_distinct env.uf t1 t2
    | _ -> X.R.query (A.neg a) (Uf.find env.uf) (class_of env) env.relation

  (* Fresh environment in which [T.vrai] and [T.faux] are assumed distinct. *)
  let empty _ =
    let env = { use = MX.empty; uf = Uf.empty; relation = X.R.empty () } in
    assume (A.make (A.Neq (T.vrai, T.faux))) Ex.empty env

end
