(*
 * Graph: generic graph library
 * Copyright (C) 2004
 * Sylvain Conchon, Jean-Christophe Filliatre and Julien Signoles
 * 
 * This software is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License version 2, as published by the Free Software Foundation.
 * 
 * This software is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * 
 * See the GNU Library General Public License version 2 for more details
 * (enclosed in the file LGPL).
 *)

(* $Id: per_imp.ml,v 1.20 2004/02/23 10:06:30 signoles Exp $ *)

(* Common implementation to persistent and imperative graphs. *)

open Sig
open Util

(* Global, mutable counter for abstract vertices, initialised to [min_int].
   NOTE(review): nothing in the code shown here reads or increments it;
   presumably the abstract-vertex code built on this module does — confirm. *)
let cpt_vertex : int ref = ref min_int

(* Signature of vertices: a comparable and hashable type [t], each vertex
   carrying a [label] from which a vertex can be built with [create]. *)
module type VERTEX = sig
  type t
  type label

  val compare : t -> t -> int
  val equal : t -> t -> bool
  val hash : t -> int

  val create : label -> t
  val label : t -> label
end

(* Signature of edges: an ordered type [t] joining a source and a
   destination [vertex] and carrying a [label]. *)
module type EDGE = sig
  type vertex
  type label
  type t

  val compare : t -> t -> int

  val create : vertex -> label -> vertex -> t
  val src : t -> vertex
  val dst : t -> vertex
  val label : t -> label
end

(* Common signature of an association table that is either imperative or
   persistent.  ['a return] is the result type of [empty] and of the
   updating operations [add] and [remove]: [Make_Hashtbl] sets it to [unit]
   (in-place update), [Make_Map] to ['a t] (a new table is returned). *)
module type HM = sig
  type key
  type 'a t
  type 'a return

  val create : unit -> 'a t
  val empty : 'a return
  val copy : 'a t -> 'a t

  val is_empty : 'a t -> bool
  val mem : key -> 'a t -> bool
  val find : key -> 'a t -> 'a

  (* [find_and_raise k t s] behaves like [find k t], except that it raises
     [Invalid_argument s] wherever [find k t] would raise [Not_found]. *)
  val find_and_raise : key -> 'a t -> string -> 'a

  val add : key -> 'a -> 'a t -> 'a return
  val remove : key -> 'a t -> 'a return

  val iter : (key -> 'a -> unit) -> 'a t -> unit
  val fold : (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b
  val map : (key -> 'a -> key * 'a) -> 'a t -> 'a t
end

(* [HM] implementation on top of the standard (imperative) hash tables:
   updates happen in place and return [unit]. *)
module Make_Hashtbl(X: COMPARABLE) = struct
  include Hashtbl.Make(X)

  type 'a return = unit

  (* Never called, and not visible to the user thanks to signature
     constraints. *)
  let empty = ()

  (* Every table is created with an initial size of 997. *)
  let create () = create 997

  (* Traversal stops at the first binding found, via [Exit]. *)
  let is_empty h =
    try iter (fun _ _ -> raise Exit) h; true
    with Exit -> false

  (* Argument-order adapters: [HM] takes the key first and the table last,
     whereas [Hashtbl] takes the table first.  [add] overwrites any previous
     binding of the key, since it is implemented with [replace]. *)
  let add key data tbl = replace tbl key data
  let remove key tbl = remove tbl key
  let mem key tbl = mem tbl key
  let find key tbl = find tbl key

  let find_and_raise key tbl msg =
    try find key tbl with Not_found -> invalid_arg msg

  (* [map f h] builds a fresh table holding [f k v] for every binding
     [k -> v] of [h]; [h] itself is left unchanged.  If [f] sends two keys
     to the same key, the binding visited last wins. *)
  let map f h =
    let result = create () in
    iter (fun k v -> let k', v' = f k v in add k' v' result) h;
    result

end

(* [HM] implementation on top of the standard (persistent) maps: updating
   operations return the new map. *)
module Make_Map(X: COMPARABLE) = struct
  include Map.Make(X)

  type 'a return = 'a t

  (* Structural comparison with the empty map. *)
  let is_empty m = (empty = m)

  (* Never called, and not visible to the user thanks to signature
     constraints. *)
  let create () = assert false

  (* Maps are immutable: the map itself is a valid copy. *)
  let copy m = m

  (* [map f m] rebinds [f k v] for every binding [k -> v] of [m], starting
     from the empty map; if two keys are sent to the same key, the binding
     folded last wins. *)
  let map f m =
    let rebind k v acc =
      let k', v' = f k v in
      add k' v' acc
    in
    fold rebind m empty

  let find_and_raise k m msg =
    try find k m with Not_found -> invalid_arg msg
end

(* Predecessor operations derived from the edge iterators: each one walks
   over every edge of the graph and keeps those whose destination is the
   given vertex.  Each raises [Invalid_argument] (named after the operation)
   when the vertex is not in the graph. *)
module Pred(S: sig
              module PV: COMPARABLE
              module PE: EDGE with type vertex = PV.t
              type t
              val mem_vertex : PV.t -> t -> bool
              val iter_edges : (PV.t -> PV.t -> unit) -> t -> unit
              val fold_edges : (PV.t -> PV.t -> 'a -> 'a) -> t -> 'a -> 'a
              val iter_edges_e : (PE.t -> unit) -> t -> unit
              val fold_edges_e : (PE.t -> 'a -> 'a) -> t -> 'a -> 'a
            end) =
struct

  open S

  (* Vertex-based predecessors. *)

  let iter_pred f g v =
    if mem_vertex v g then
      iter_edges (fun src dst -> if PV.equal v dst then f src) g
    else invalid_arg "iter_pred"

  let fold_pred f g v =
    if mem_vertex v g then
      fold_edges
        (fun src dst acc -> if PV.equal v dst then f src acc else acc) g
    else invalid_arg "fold_pred"

  let pred g v = fold_pred (fun p acc -> p :: acc) g v []

  let in_degree g v = fold_pred (fun _ count -> count + 1) g v 0

  (* Edge-based predecessors. *)

  let iter_pred_e f g v =
    if mem_vertex v g then
      iter_edges_e (fun e -> if PV.equal v (PE.dst e) then f e) g
    else invalid_arg "iter_pred_e"

  let fold_pred_e f g v =
    if mem_vertex v g then
      fold_edges_e
        (fun e acc -> if PV.equal v (PE.dst e) then f e acc else acc) g
    else invalid_arg "fold_pred_e"

  let pred_e g v = fold_pred_e (fun e acc -> e :: acc) g v []

end

(* Operations common to every (directed) graph implementation, for graphs
   represented as a table [HM] mapping each vertex to the set [S] of its
   successors. *)
module Minimal(S: Set.S)(HM: HM) = struct

  let is_directed = true

  (* Creation and emptiness are those of the underlying table. *)
  let empty = HM.empty
  let create = HM.create
  let is_empty = HM.is_empty

  (* One table entry per vertex, one set element per edge. *)
  let nb_vertex g = HM.fold (fun _ _ count -> succ count) g 0
  let nb_edges g =
    HM.fold (fun _ succs count -> count + S.cardinal succs) g 0

  (* NOTE(review): when [v] is not a vertex of [g], [HM.find] raises
     [Not_found] here, whereas the successor functions of [Unlabeled] and
     [Labeled] raise [Invalid_argument] — confirm which behaviour callers
     expect. *)
  let out_degree g v =
    let succs = HM.find v g in
    S.cardinal succs

  let mem_vertex g v = HM.mem v g

  (* The [unsafe_] functions perform no membership check.
     [unsafe_add_vertex] binds [v] to an empty successor set (with both
     [Make_Hashtbl] and [Make_Map], [HM.add] overwrites an existing
     binding); [unsafe_add_edge] requires [v1] to be present already. *)
  let unsafe_add_vertex g v = HM.add v S.empty g
  let unsafe_add_edge g v1 v2 =
    let succs = HM.find v1 g in
    HM.add v1 (S.add v2 succs) g

  (* Vertex traversals ignore the successor sets. *)
  let iter_vertex f = HM.iter (fun vertex _ -> f vertex)
  let fold_vertex f = HM.fold (fun vertex _ acc -> f vertex acc)

end

(* Common implementation of all the unlabeled (directed) graphs: the table
   [HM] maps every vertex to the set [S] of its successors. *)
module Unlabeled(V: COMPARABLE)(HM: HM with type key = V.t) = struct

  module S = Set.Make(V)

  (* An unlabeled edge is the pair (source, destination); its label is
     [()]. *)
  module E = struct
    type vertex = V.t
    include OTProduct(V)(V)
    let src (v1, _) = v1
    let dst (_, v2) = v2
    type label = unit
    let label _ = ()
    let create v1 () v2 = (v1, v2)
  end

  (* [false] when [v1] is not a vertex of [g]. *)
  let mem_edge g v1 v2 =
    try S.mem v2 (HM.find v1 g) with Not_found -> false

  let mem_edge_e g (v1, v2) = mem_edge g v1 v2

  (* No membership check: [HM.find] fails when [v1] is not a vertex. *)
  let unsafe_remove_edge g v1 v2 =
    let succs = HM.find v1 g in
    HM.add v1 (S.remove v2 succs) g
  let unsafe_remove_edge_e g (v1, v2) = unsafe_remove_edge g v1 v2

  (* Raises [Invalid_argument "remove_edge"] if [v1] or [v2] is not a
     vertex of [g]. *)
  let remove_edge g v1 v2 =
    if not (HM.mem v2 g) then invalid_arg "remove_edge";
    let succs = HM.find_and_raise v1 g "remove_edge" in
    HM.add v1 (S.remove v2 succs) g

  let remove_edge_e g (v1, v2) = remove_edge g v1 v2

  (* Successor traversals; [iter_succ], [fold_succ] and [succ] raise
     [Invalid_argument] named after themselves when [v] is not a vertex. *)
  let iter_succ f g v =
    let succs = HM.find_and_raise v g "iter_succ" in
    S.iter f succs
  let fold_succ f g v =
    let succs = HM.find_and_raise v g "fold_succ" in
    S.fold f succs

  let iter_succ_e f g v = iter_succ (fun w -> f (v, w)) g v
  let fold_succ_e f g v = fold_succ (fun w -> f (v, w)) g v

  let succ g v = S.elements (HM.find_and_raise v g "succ")
  let succ_e g v = fold_succ_e (fun e acc -> e :: acc) g v []

  (* Apply [f] to every vertex, both as a key and inside successor sets. *)
  let map_vertex f =
    let map_set s = S.fold (fun w acc -> S.add (f w) acc) s S.empty in
    HM.map (fun v s -> (f v, map_set s))

  (* Edge traversals, in the shape required by [Pred]. *)
  module I = struct
    type t = S.t HM.t
    module PV = V
    module PE = E
    let iter_edges f = HM.iter (fun v succs -> S.iter (f v) succs)
    let fold_edges f = HM.fold (fun v succs acc -> S.fold (f v) succs acc)
    let iter_edges_e f = iter_edges (fun v1 v2 -> f (v1, v2))
    let fold_edges_e f = fold_edges (fun v1 v2 acc -> f (v1, v2) acc)
  end
  include I

  (* Predecessors, obtained by scanning all the edges. *)
  include Pred(struct include I let mem_vertex = HM.mem end)

end

(* Common implementation of all the labeled (directed) graphs.  The table
   [HM] maps every vertex to the set [S] of its outgoing (destination,
   label) pairs, so several edges with distinct labels may join the same
   two vertices. *)
module Labeled(V: COMPARABLE)(E: ORDERED_TYPE)(HM: HM with type key = V.t) = 
struct
  
  module S = Set.Make(OTProduct(V)(E))
    
  (* Edges are (source, label, destination) triples, compared through
     [OTProduct] on (source, (label, destination)). *)
  module E = struct
    type vertex = V.t
    type label = E.t
    type t = vertex * label * vertex
    let src (v, _, _) = v
    let dst (_, _, v) = v
    let label (_, l, _) = l
    let create v1 l v2 = v1, l, v2
    module C = OTProduct(V)(OTProduct(E)(V))
    let compare (x1, x2, x3) (y1, y2, y3) = 
      C.compare (x1, (x2, x3)) (y1, (y2, y3))
  end

  (* [mem_edge g v1 v2]: is there an edge from [v1] to [v2], whatever its
     label?  [false] if [v1] is not a vertex of [g]. *)
  let mem_edge g v1 v2 = 
    try
      S.exists (fun (v2', _) -> V.equal v2 v2') (HM.find v1 g)
    with Not_found ->
      false

  (* [mem_edge_e g (v1, l, v2)]: is the edge from [v1] to [v2] labeled [l]
     in [g]?  The label takes part in the test, consistently with
     [unsafe_remove_edge_e] and [remove_edge_e], which act on exactly the
     pair [(v2, l)]; an edge between the same vertices with another label
     does not count.  [false] if [v1] is not a vertex of [g]. *)
  let mem_edge_e g (v1, l, v2) =
    try
      S.mem (v2, l) (HM.find v1 g)
    with Not_found ->
      false

  (* The [unsafe_] removals perform no membership check: [HM.find] fails
     when [v1] is not a vertex.  [unsafe_remove_edge] removes every edge
     from [v1] to [v2], whatever its label. *)
  let unsafe_remove_edge g v1 v2 = 
    HM.add v1 (S.filter 
		 (fun (v2', _) -> not (V.equal v2 v2')) (HM.find v1 g)) g

  let unsafe_remove_edge_e g (v1, l, v2) = 
    HM.add v1 (S.remove (v2, l) (HM.find v1 g)) g

  (* Remove every edge from [v1] to [v2], whatever its label.
     Raises [Invalid_argument "remove_edge"] if [v1] or [v2] is not a vertex
     of [g]. *)
  let remove_edge g v1 v2 =
    if not (HM.mem v2 g) then invalid_arg "remove_edge";
    HM.add v1 (S.filter 
		 (fun (v2', _) -> not (V.equal v2 v2'))
		 (HM.find_and_raise v1 g "remove_edge")) g

  (* Remove the edge from [v1] to [v2] labeled [l].
     Raises [Invalid_argument "remove_edge_e"] if [v1] or [v2] is not a
     vertex of [g] (both checks now report the same function name). *)
  let remove_edge_e g (v1, l, v2) = 
    if not (HM.mem v2 g) then invalid_arg "remove_edge_e";
    HM.add v1 (S.remove (v2, l) (HM.find_and_raise v1 g "remove_edge_e")) g

  (* Successor traversals.  A successor is visited once per label of an edge
     leading to it.  Each raises [Invalid_argument] named after the function
     when [v] is not a vertex of [g]. *)
  let iter_succ f g v = 
    S.iter (fun (w, _) -> f w) (HM.find_and_raise v g "iter_succ")
  let fold_succ f g v = 
    S.fold (fun (w, _) -> f w) (HM.find_and_raise v g "fold_succ")

  let iter_succ_e f g v = 
    S.iter (fun (w, l) -> f (v, l, w)) (HM.find_and_raise v g "iter_succ_e")
  let fold_succ_e f g v = 
    S.fold (fun (w, l) -> f (v, l, w)) (HM.find_and_raise v g "fold_succ_e")

  (* May list a successor several times, once per label. *)
  let succ g v = fold_succ (fun w l -> w :: l) g v []
  let succ_e g v = fold_succ_e (fun e l -> e :: l) g v []

  (* Apply [f] to every vertex, both as a key and inside the
     (destination, label) pairs; labels are kept. *)
  let map_vertex f = 
    HM.map (fun v s -> 
	      f v, S.fold (fun (v, l) s -> S.add (f v, l) s) s S.empty)
    
  (* Edge traversals, in the shape required by [Pred]. *)
  module I = struct
    type t = S.t HM.t
    module PV = V
    module PE = E

    let iter_edges f = HM.iter (fun v -> S.iter (fun (w, _) -> f v w))
    let fold_edges f = HM.fold (fun v -> S.fold (fun (w, _) -> f v w))
    let iter_edges_e f = 
      HM.iter (fun v -> S.iter (fun (w, l) -> f (v, l, w)))
    let fold_edges_e f = 
      HM.fold (fun v -> S.fold (fun (w, l) -> f (v, l, w)))
  end
  include I

  (* Predecessors, obtained by scanning all the edges. *)
  include Pred(struct include I let mem_vertex = HM.mem end)

end

(* Vertex module and vertex-indexed table of the concrete graphs.  A
   concrete vertex is its own label: [create] and [label] are identities. *)
module ConcreteVertex
  (F : functor(X: COMPARABLE) -> HM with type key = X.t)
  (V: COMPARABLE) = 
struct

  module V = struct
    include V
    type label = t
    let create (l : label) : t = l
    let label (v : t) : label = v
  end

  (* Association table whose keys are the vertices. *)
  module HM = F(V)

end

(* Abstract [G]. *)
(* JS: factoring [remove_edge] into [Make_Abstract] is impossible because of
   an OCaml bug; keep the code below commented out until it is fixed. *)
(*
module Make_Abstract(G: sig 
		       type return 
		       include Graph.S
		       val remove_edge : t -> V.t -> V.t -> return
		       val remove_edge_e : t -> E.t -> return
		     end) = 
struct
*)
module Make_Abstract(G: Sig.G) = struct
  (* Wraps a graph [G] into a record holding the underlying graph in
     [edges] and a counter [size], which [nb_vertex] and [is_empty] below
     read instead of traversing the graph.  Most other operations simply
     forward to [G] on the [edges] field.
     NOTE(review): [size] is never updated in this functor; keeping it equal
     to the number of vertices is presumably the job of the code that
     adds/removes vertices — confirm. *)

  (* The record type plus the traversals [Pred] needs. *)
  module I = struct
    type t = { edges : G.t; mutable size : int }
	(* BE CAREFUL: [size] should only be mutated by the imperative version.
	   Since OCaml 3.07 has no extensible records, and for genericity,
	   [size] is mutable in both the imperative and the persistent
	   implementations.
	   Do not modify [size] in the persistent implementation! *)

    module PV = G.V
    module PE = struct include G.E type vertex = PV.t end

    let iter_edges f g = G.iter_edges f g.edges
    let fold_edges f g = G.fold_edges f g.edges
    let iter_edges_e f g = G.iter_edges_e f g.edges
    let fold_edges_e f g = G.fold_edges_e f g.edges
    (* Vertex-first argument order, as required by [Pred]; overridden
       below by a graph-first [mem_vertex]. *)
    let mem_vertex v g = G.mem_vertex g.edges v
  end
  include I

  (* Generic predecessor operations; [in_degree] is overridden below. *)
  include Pred(I)

  (* optimisations: use [size], or delegate to [G] *)

  let is_empty g = g.size = 0
  let nb_vertex g = g.size
  let out_degree g v = G.out_degree g.edges v
  let in_degree g v = G.in_degree g.edges v

  (* redefinitions: forward to [G] on the [edges] field.  Note that this
     [mem_vertex] takes the graph first, overriding [I.mem_vertex]. *)

  let nb_edges g = G.nb_edges g.edges
  let succ g = G.succ g.edges
  let mem_vertex g = G.mem_vertex g.edges
  let mem_edge g = G.mem_edge g.edges
  let mem_edge_e g = G.mem_edge_e g.edges

(* JS: factoring [remove_edge] here is impossible because of an OCaml bug;
   keep the code below commented out until it is fixed. *)
(*
  let remove_edge g = G.remove_edge g.edges
  let remove_edge_e g = G.remove_edge_e g.edges
*)
  let iter_vertex f g = G.iter_vertex f g.edges
  let fold_vertex f g = G.fold_vertex f g.edges
  let iter_succ f g = G.iter_succ f g.edges
  let fold_succ f g = G.fold_succ f g.edges
  let succ_e g = G.succ_e g.edges
  let iter_succ_e f g = G.iter_succ_e f g.edges
  let fold_succ_e f g = G.fold_succ_e f g.edges
  (* [size] is copied unchanged.  NOTE(review): if [f] is not injective,
     vertices may merge and [size] could then overcount — confirm that
     callers only pass injective functions. *)
  let map_vertex f g = { g with edges = G.map_vertex f g.edges }

end

(* Build persistent (resp. imperative) graphs from a persistent (resp.
   imperative) association table [F], e.g. [Make_Map] (resp.
   [Make_Hashtbl]). *)
module Make(F : functor(X: COMPARABLE) -> HM with type key = X.t) = struct

  module Digraph = struct

    (* Concrete unlabeled digraph: vertices are the values of [V]
       themselves. *)
    module Concrete(V: COMPARABLE) = struct

      include ConcreteVertex(F)(V)
      include Unlabeled(V)(HM)
      include Minimal(S)(HM)

    end

    (* Concrete labeled digraph, with edge labels of type [E.t].
       [default] re-exports [E.default] (presumably the label used when
       none is given — confirm in [Sig]). *)
    module ConcreteLabeled(V: COMPARABLE)(E: ORDERED_TYPE_DFT) = struct

      let default = E.default

      include ConcreteVertex(F)(V)
      include Labeled(V)(E)(HM)
      include Minimal(S)(HM)

    end

    (* Abstract unlabeled digraph: the digraph [G] built from [Unlabeled]
       and [Minimal] is wrapped by [Make_Abstract], which pairs it with a
       vertex counter. *)
    module Abstract(V: VERTEX) = struct

      module G = struct
        module V = V
        module HM = F(V)
        include Unlabeled(V)(HM) 
        include Minimal(S)(HM)
      end

      (* export some definitions of G *)
      module V = G.V
      module E = G.E
      module HM = G.HM
      module S = G.S
      let unsafe_add_vertex = G.unsafe_add_vertex
      let unsafe_add_edge = G.unsafe_add_edge
      let unsafe_remove_edge = G.unsafe_remove_edge
      let unsafe_remove_edge_e = G.unsafe_remove_edge_e
      let is_directed = G.is_directed
      let empty = G.empty
      let create = G.create

(* JS: factoring [remove_edge] into [Make_Abstract] is impossible because of
   an OCaml bug; keep the code below commented out until it is fixed. *)
(*      
      include Make_Abstract(struct type return = S.t HM.return include G end)
*)
      include Make_Abstract(G)
      (* JS: remove the next 2 lines once the factoring becomes possible.
         They forward to [G] on the [edges] field and return whatever [G]
         returns ([S.t HM.return]), not a record of type [t]. *)
      let remove_edge g = G.remove_edge g.edges
      let remove_edge_e g = G.remove_edge_e g.edges

    end

    (* Abstract labeled digraph: same construction as [Abstract], on top of
       [Labeled]; [default] re-exports [E.default]. *)
    module AbstractLabeled(V: VERTEX)(E: ORDERED_TYPE_DFT) = struct

      let default = E.default

      module G = struct
        module V = V
        module HM = F(V)
        include Labeled(V)(E)(HM) 
        include Minimal(S)(HM)
      end

      (* export some definitions of G *)
      module V = G.V
      module E = G.E
      module HM = G.HM
      module S = G.S
      let unsafe_add_vertex = G.unsafe_add_vertex
      let unsafe_add_edge = G.unsafe_add_edge
      let unsafe_remove_edge = G.unsafe_remove_edge
      let unsafe_remove_edge_e = G.unsafe_remove_edge_e
      let is_directed = G.is_directed
      let empty = G.empty
      let create = G.create

(* JS: factoring [remove_edge] into [Make_Abstract] is impossible because of
   an OCaml bug; keep the code below commented out until it is fixed. *)
(*      
      include Make_Abstract(struct type return = S.t HM.return include G end)
*)
      include Make_Abstract(G)
      (* JS: remove the next 2 lines once the factoring becomes possible.
         They forward to [G] on the [edges] field and return whatever [G]
         returns ([S.t HM.return]), not a record of type [t]. *)
      let remove_edge g = G.remove_edge g.edges
      let remove_edge_e g = G.remove_edge_e g.edges

    end

  end

end

(* Undirected graphs built on top of a directed graph [G].  The edge
   iterators below deduplicate reversed pairs, so they presumably rely on
   [G] storing every undirected edge {v1, v2} both as v1 -> v2 and as
   v2 -> v1 — confirm with the undirected constructors. *)
module Graph = struct

  (* Predecessors are the successors in an undirected graph: every
     predecessor operation is an alias of the matching successor
     operation. *)
  module Pred(S: sig
                module PV : VERTEX
                module PE : EDGE
                type t
                val succ : t -> PV.t -> PV.t list
                val out_degree : t -> PV.t -> int
                val iter_succ : (PV.t -> unit) -> t -> PV.t -> unit
                val fold_succ : (PV.t -> 'a -> 'a) -> t -> PV.t -> 'a -> 'a
                val succ_e : t -> PV.t -> PE.t list
                val iter_succ_e : (PE.t -> unit) -> t -> PV.t -> unit
                val fold_succ_e : (PE.t -> 'a -> 'a) -> t -> PV.t -> 'a -> 'a
              end) = 
  struct
    open S
    let pred = succ
    let in_degree = out_degree
    let iter_pred = iter_succ
    let fold_pred = fold_succ
    let pred_e = succ_e
    let iter_pred_e = iter_succ_e
    let fold_pred_e = fold_succ_e
  end

  (* Same functor for unlabeled and labeled concrete graphs *)
  module Concrete(G: Sig.G) = struct

    module M = struct

      include G
      module PV = V
      module PE = struct type vertex = V.t include E end

      let is_directed = false

      (* Edge iterators reporting each undirected edge once, although [G]
         visits it in both directions: when the pair (v1, v2) is reported,
         the reverse pair (v2, v1) is recorded in a hash table so that it
         is skipped when met later.

         The table is created only once both the function and the graph
         are supplied, i.e. once per traversal.  Creating it as soon as [f]
         alone is given would share it between all traversals made through
         a partial application such as [let it = iter_edges f in ...]: the
         second traversal would then silently skip every edge marked by the
         first one. *)

      module H = Hashtbl.Make(HTProduct(G.V)(G.V))

      let iter_edges f g =
        let seen = H.create 97 in
        G.iter_edges
          (fun v1 v2 ->
             if not (H.mem seen (v1, v2)) then begin
               H.add seen (v2, v1) ();
               f v1 v2
             end)
          g

      let fold_edges f g acc =
        let seen = H.create 97 in
        G.fold_edges
          (fun v1 v2 acc ->
             if H.mem seen (v1, v2) then acc
             else begin
               H.add seen (v2, v1) ();
               f v1 v2 acc
             end)
          g acc

      let iter_edges_e f g =
        let seen = H.create 97 in
        G.iter_edges_e
          (fun e ->
             let v1, v2 as vv = G.E.src e, G.E.dst e in
             if not (H.mem seen vv) then begin
               H.add seen (v2, v1) ();
               f e
             end)
          g

      let fold_edges_e f g acc =
        let seen = H.create 97 in
        G.fold_edges_e
          (fun e acc ->
             let v1, v2 as vv = G.E.src e, G.E.dst e in
             if H.mem seen vv then acc
             else begin
               H.add seen (v2, v1) ();
               f e acc
             end)
          g acc

    end

    include M

    include Pred(M)

  end

  (* Same functor for unlabeled and labeled abstract graphs *)
  module Abstract(G: sig include Sig.G val tag : V.t -> int end) = struct

    module M = struct
      include G

      module PV = V
      module PE = struct type vertex = V.t include E end

      let is_directed = false

      (* Edge iterators reporting each undirected edge only from the
         endpoint with the smaller (or equal) [tag]; no table is needed.
         Presumably [tag] is injective (a unique vertex number) — if two
         distinct vertices shared a tag, their edge would be reported
         twice; confirm. *)

      let iter_edges f =
        iter_edges (fun v1 v2 -> if tag v1 <= tag v2 then f v1 v2)

      let fold_edges f =
        fold_edges 
          (fun v1 v2 acc -> if tag v1 <= tag v2 then f v1 v2 acc else acc)

      let iter_edges_e f =
        iter_edges_e (fun e -> if tag (E.src e) <= tag (E.dst e) then f e)

      let fold_edges_e f =
        fold_edges_e 
          (fun e acc -> 
             if tag (E.src e) <= tag (E.dst e) then f e acc else acc)

    end
    include M

    include Pred(M)

  end

end
