(* PostgreSQL database interface for mod_caml programs.
 * Copyright (C) 2003-2004 Merjis Ltd.
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the Free
 * Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 *
 * $Id: dbi_postgresql.ml,v 1.13 2005/11/17 09:56:53 rwmj Exp $
 *)

open Printf

module Pg = Postgresql

(* PCRE regular expressions for parsing timestamps and intervals. *)
(* Matches PostgreSQL's ISO timestamp output: a mandatory YYYY-MM-DD date,
   optionally followed by HH:MM[:SS][.fraction][+/-HH].  Capture groups:
   1-3 year/month/day, 4-5 hour/minute, 6 seconds, 7 fractional-second
   digits, 8 sign and 9 hours of the UTC offset.  The pattern is not
   anchored. *)
let re_timestamp =
  let pattern =
    String.concat ""
      [ "(?:(\\d\\d\\d\\d)-(\\d\\d)-(\\d\\d))     # date (YYYY-MM-DD)\n";
        "\\s*                                     # space between date & time\n";
        "(?:(\\d\\d):(\\d\\d)                     # HH:MM\n";
        "   (?::(\\d\\d))?                        # optional :SS\n";
        "   (?:\\.(\\d+))?                        # optional .microseconds\n";
        "   (?:([+-])(\\d\\d))?                   # optional +/- offset UTC\n";
        ")?" ]
  in
  Pcre.regexp ~flags:[`EXTENDED] pattern
(* Matches PostgreSQL's default interval output, e.g.
   "1 year 2 mons 3 days 04:05:06".  Every component is optional.  Capture
   groups: 1 years, 2 months, 3 days, 4 hours, 5 minutes, 6 seconds.
   The hour field accepts any number of digits because the server prints
   intervals of 100 hours or more as e.g. "100:00:00"; with a fixed two-digit
   hour such values failed to match. *)
let re_interval =
  Pcre.regexp ~flags:[`EXTENDED]
    ("(?:(\\d+)\\syears?)?                     # years\n"^
     "\\s*                                     # \n"^
     "(?:(\\d+)\\smons?)?                      # months\n"^
     "\\s*                                     # \n"^
     "(?:(\\d+)\\sdays?)?                      # days\n"^
     "\\s*                                     # \n"^
     "(?:(\\d+):(\\d\\d)                       # H...H:MM\n"^
     "   (?::(\\d\\d))?                        # optional :SS\n"^
     ")?")

(* [string_of_timestamp (date, time)] renders a Dbi timestamp as a quoted
   PostgreSQL literal such as ['2005-11-17 09:56:53.000000+01'] (the UTC
   offset is appended only when [time.timezone] is set).  The sentinel
   years [max_int] and [min_int], which [timestamp_of_string] uses for the
   special values, are rendered as ['infinity'] and ['-infinity']. *)
let string_of_timestamp (date, time) =
  if date.Dbi.year = max_int then "'infinity'"
  else if date.Dbi.year = min_int then "'-infinity'"
  else begin
    (* Zero-pad microseconds to six digits: the server reads the digits
       after the point as a decimal fraction of a second, so ".5" would
       mean 500000 microseconds rather than 5. *)
    let base =
      sprintf "'%04d-%02d-%02d %02d:%02d:%02d.%06d"
        date.Dbi.year date.Dbi.month date.Dbi.day
        time.Dbi.hour time.Dbi.min time.Dbi.sec time.Dbi.microsec in
    match time.Dbi.timezone with
    | None -> base ^ "'"
    | Some tz -> base ^ sprintf "%+03d'" tz
  end

(* [timestamp_of_string str] parses a TIMESTAMP / TIMESTAMPTZ value as
   returned by the server: "YYYY-MM-DD" optionally followed by
   "HH:MM[:SS][.fraction][+/-HH]".  Missing time fields default to 0.
   The special values "infinity" and "-infinity" are mapped to sentinel
   timestamps whose year is [max_int] / [min_int] respectively.
   @raise Failure if [str] is not a recognisable timestamp. *)
let timestamp_of_string =
  let int_opt s = if s = "" then 0 else int_of_string s in
  (* The fractional part is a decimal fraction of a second, not a count of
     microseconds: ".5" is 500000 microseconds.  Right-pad (or truncate)
     the digits to exactly six before converting. *)
  let microsec_of_fraction frac =
    let len = String.length frac in
    if len = 0 then 0
    else if len >= 6 then int_of_string (String.sub frac 0 6)
    else int_of_string (frac ^ String.make (6 - len) '0')
  in
  fun str ->
    try
      let sub = Pcre.extract ~rex:re_timestamp str in
      ({ Dbi.year = int_of_string sub.(1);
         Dbi.month = int_of_string sub.(2);
         Dbi.day = int_of_string sub.(3)   },
       (* The time part of [re_timestamp] is optional, so its groups may be
          empty; treat missing fields as 0. *)
       { Dbi.hour = int_opt sub.(4);
         Dbi.min = int_opt sub.(5);
         Dbi.sec = int_opt sub.(6);
         Dbi.microsec = microsec_of_fraction sub.(7);
         Dbi.timezone =
           if sub.(9) = "" then None
           else Some (let tz = int_of_string sub.(9) in
                      if sub.(8) = "-" then -tz else tz);
       })
    with Not_found ->
      (* FIXME: improve this *)
      if str = "infinity" then
        ({ Dbi.year = max_int;
           Dbi.month = 12;
           Dbi.day = 31; },
         { Dbi.hour = 23;  Dbi.min = 59;  Dbi.sec = 59;
           Dbi.microsec = 1000000;  Dbi.timezone = None })
      else if str = "-infinity" then
        ({ Dbi.year = min_int;
           Dbi.month = 1;
           Dbi.day = 1; },
         { Dbi.hour = 0;  Dbi.min = 0;  Dbi.sec = 0;
           Dbi.microsec = 0;  Dbi.timezone = None })
      else
        failwith ("timestamp_of_string: bad timestamp: " ^ str)

(* [string_of_interval (date, time)] renders a Dbi interval as a quoted
   PostgreSQL literal of the form ['Y years M mons D days HH:MM:SS.uuuuuu'];
   the year/month/day fields of [date] carry the interval's years, months
   and days.  Microseconds are zero-padded to six digits because the server
   reads the digits after the point as a decimal fraction of a second. *)
let string_of_interval (date, time) =
  sprintf "'%d years %d mons %d days %02d:%02d:%02d.%06d'"
    date.Dbi.year date.Dbi.month date.Dbi.day
    time.Dbi.hour time.Dbi.min time.Dbi.sec time.Dbi.microsec

(* [interval_of_string str] parses an INTERVAL value in the server's default
   output style, e.g. "1 year 2 mons 3 days 04:05:06", "3 days" or
   "04:05:06".  Every component is optional and defaults to 0; the
   year/month/day fields of the returned date hold the interval's years,
   months and days.  Fractional seconds are not parsed ([microsec] is
   always 0).
   @raise Failure if no interval component is recognised at the start of
   [str]. *)
let interval_of_string =
  let int_opt s = if s = "" then 0 else int_of_string s in
  fun str ->
    try
      let sub = Pcre.extract ~rex:re_interval str in
      (* Every part of [re_interval] is optional, so it always matches —
         possibly with an empty match at the start of [str].  Treat an empty
         match as "not an interval" instead of returning all zeros. *)
      if sub.(0) = "" then raise Not_found;
      ({ Dbi.year = int_opt sub.(1);
         Dbi.month = int_opt sub.(2);
         Dbi.day = int_opt sub.(3);   },
       { Dbi.hour = int_opt sub.(4);
         Dbi.min = int_opt sub.(5);
         Dbi.sec = int_opt sub.(6);
         Dbi.microsec = 0;
         Dbi.timezone = None;
       })
    with
      Not_found -> failwith ("interval_of_string: bad interval: " ^ str)

(* [date_of_string s] parses a "YYYY-MM-DD" date (the DATE output format)
   into a Dbi date record.  Malformed input makes [Scanf.sscanf] raise. *)
let date_of_string s =
  let build y m d = { Dbi.year = y; Dbi.month = m; Dbi.day = d } in
  Scanf.sscanf s "%d-%d-%d" build

(* [time_of_string s] parses the leading "HH:MM:SS" of [s] into a Dbi time.
   Anything after the seconds (fraction, zone) is ignored: [microsec] is
   always 0 and [timezone] always [None].  Malformed input makes
   [Scanf.sscanf] raise. *)
let time_of_string s =
  let build hh mm ss =
    { Dbi.hour = hh; Dbi.min = mm; Dbi.sec = ss;
      Dbi.microsec = 0; Dbi.timezone = None } in
  Scanf.sscanf s "%d:%d:%d" build

(* [encode_sql_t v] returns a string suitable for substitution of "?"
   in a SQL query. *)
let encode_sql_t =
  (* Render a float as a SQL literal.  NaN and the infinities have no bare
     numeric literal in SQL ([string_of_float] would emit "nan"/"inf"), so
     they are sent as the quoted spellings PostgreSQL accepts for floating
     point types.  Finite values use the shortest of %.15g/%.16g/%.17g that
     reads back as the same float, so FLOAT8 values round-trip exactly
     ([string_of_float] keeps only 12 significant digits). *)
  let sql_of_float f =
    match classify_float f with
    | FP_nan -> "'NaN'"
    | FP_infinite -> if f > 0. then "'Infinity'" else "'-Infinity'"
    | FP_normal | FP_subnormal | FP_zero ->
        let s15 = sprintf "%.15g" f in
        if float_of_string s15 = f then s15
        else begin
          let s16 = sprintf "%.16g" f in
          if float_of_string s16 = f then s16 else sprintf "%.17g" f
        end
  in
  function
  | `Null -> "null"
  | `Int i ->
      (* As we use [`Int] for INTEGER and SMALLINT, we quote the
         integer to delay type resolution and so allow the system to
         use an index (postgres manual "8.1.1. Integer Types"). *)
      "'" ^ string_of_int i ^ "'"
  | `Int32 i -> "'" ^ Int32.to_string i ^ "'"
  | `Int64 i -> "'" ^ Int64.to_string i ^ "'"
  | `Float f -> sql_of_float f
  | `String s -> Dbi.string_escaped s
  | `Bool b -> if b then "true" else "false"
  | `Bigint i -> Big_int.string_of_big_int i
  | `Decimal d -> Dbi.Decimal.to_string d
  | `Date d -> sprintf "'%04i-%02i-%02i'" d.Dbi.year d.Dbi.month d.Dbi.day
  | `Time t -> sprintf "'%02i:%02i:%02i'" t.Dbi.hour t.Dbi.min t.Dbi.sec
  | `Timestamp t -> string_of_timestamp t
  | `Interval i -> string_of_interval i
  | `Blob s -> Dbi.string_escaped s (* FIXME *)
  | `Binary s -> "'" ^ Pg.escape_bytea s ^ "'::bytea"
  | `Unknown s -> Dbi.string_escaped s



(* The Postgresql library (by Markus Mottl) now exposes symbolic names for
   the column types, so we no longer map the numeric type OIDs from
   pg_type.h ourselves.

   [decode_sql_t is_null ty v] converts the textual field value [v], whose
   column type is [ty], into a Dbi value.  A NULL field ([is_null]) is
   [`Null] whatever its type; types without a dedicated case become
   [`Unknown v].  Parse failures propagate from the individual converters. *)
let decode_sql_t is_null ty v =
  match is_null, ty with
  | true, _ -> `Null
  | false, (Pg.INT2 | Pg.INT4) -> `Int (int_of_string v)
  | false, (Pg.FLOAT4 | Pg.FLOAT8) -> `Float (float_of_string v)
  | false, (Pg.TEXT | Pg.BPCHAR | Pg.CHAR | Pg.VARCHAR) -> `String v
  | false, Pg.BOOL -> `Bool (v = "t")
  | false, Pg.INT8 -> `Int64 (Int64.of_string v)
  | false, Pg.NUMERIC -> `Decimal (Dbi.Decimal.of_string v)
  | false, Pg.DATE -> `Date (date_of_string v)
  | false, (Pg.ABSTIME | Pg.RELTIME | Pg.TIME | Pg.TIMETZ) ->
      (* FIXME: what about TZ?  [time_of_string] drops any zone suffix.
         NOTE(review): ABSTIME/RELTIME are parsed as a time of day here —
         confirm that matches their server output. *)
      `Time (time_of_string v)
  | false, (Pg.TIMESTAMP | Pg.TIMESTAMPTZ) -> `Timestamp (timestamp_of_string v)
  | false, (Pg.TINTERVAL | Pg.INTERVAL) -> `Interval (interval_of_string v)
  | false, Pg.BYTEA -> `Binary (Pg.unescape_bytea v)
  | false, _ -> `Unknown v


(* A statement belonging to connection [dbh].  [conn] is the underlying
   Postgresql connection, [in_transaction] the transaction flag shared by
   every statement of that connection, and [original_query] the SQL text,
   whose "?" placeholders [execute] replaces with its arguments. *)
class statement dbh (conn : Pg.connection) in_transaction original_query =
  let query = Dbi.split_query original_query in
object (self)
  inherit Dbi.statement dbh

  (* Result of the last executed query that returned rows, if any. *)
  val mutable tuples = None
  (* Column names of [tuples], computed lazily by [names]. *)
  val mutable name_list = None
  (* Index of the row the next [fetch1] returns. *)
  val mutable next_tuple = 0
  (* Row and column counts of [tuples]. *)
  val mutable ntuples = 0
  val mutable nfields = 0

  (* Substitute [args] for the placeholders and send the query.  If no
     transaction is open on the connection, BEGIN WORK is issued first.
     Raises [Dbi.SQL_error] on a fatal error or bad response. *)
  method execute args =
    if dbh#debug then (
      eprintf "Dbi_postgres: dbh %d: execute %s\n" dbh#id original_query;
      flush stderr
    );
    if dbh#closed then
      failwith "Dbi_postgres: execute called on a closed database handle.";
    (* Finish previous statement, if any. *)
    self#finish ();

    (* In transaction? If not we need to issue a BEGIN WORK command. *)
    if not !in_transaction then begin
      (* Set the flag first so the nested execute of BEGIN WORK does not
         recurse.  If BEGIN WORK fails, clear it again: otherwise later
         statements would believe a transaction is open and silently run
         outside one. *)
      in_transaction := true;
      (try
         let sth = dbh#prepare_cached "BEGIN WORK" in
         sth#execute []
       with exn ->
         in_transaction := false;
         raise exn)
    end;

    let query = (* substitute args *)
      Dbi.make_query "Dbi_postgres: execute called with wrong number of args."
        encode_sql_t query args in
    (* Send the query to the database. *)
    let res = conn#exec query in

    match res#status with
    | Pg.Empty_query
    | Pg.Command_ok ->
        ()
    | Pg.Tuples_ok ->
        tuples <- Some res;
        name_list <- None;
        next_tuple <- 0;
        ntuples <- res#ntuples;
        nfields <- res#nfields
    | Pg.Copy_out
    | Pg.Copy_in ->
        failwith "XXX copyin/copyout not implemented"
    | Pg.Bad_response
    | Pg.Fatal_error ->
        (* dbh#close (); -- used to do this, not a good idea *)
        raise (Dbi.SQL_error (res#error))
    | Pg.Nonfatal_error ->
        prerr_endline ("Dbi_postgres: non-fatal error: " ^ res#error)

  (* Return the next row of the current result, one decoded value per
     column.  Raises [Not_found] once every row has been returned; fails if
     no row-returning query has been executed. *)
  method fetch1 () =
    if dbh#debug then (
      eprintf "Dbi_postgres: dbh %d: fetch1\n" dbh#id;
      flush stderr
    );

    match tuples with
    | None -> failwith "Dbi_postgres: call execute before calling fetch."
    | Some tuples ->
        if next_tuple >= ntuples then raise Not_found;

        (* Fetch each field in the tuple. *)
        let row =
          let rec loop i =
            if i < nfields then (
              let field = (* FIXME: what about binary tuples?? *)
                decode_sql_t (tuples#getisnull next_tuple i)
                  (tuples#ftype i)
                  (tuples#getvalue next_tuple i) in
              field :: loop (i+1)
            ) else
              []
          in
          loop 0 in

        next_tuple <- next_tuple + 1;
        row

  (* Column names of the current result, cached after the first call. *)
  method names =
    match tuples with
    | None -> failwith "Dbi_postgres.statement#names"
    | Some tuples ->
        begin match name_list with
        | Some l -> l
        | None ->
            let rec loop acc i =
              if i < 0 then acc
              else loop (tuples#fname i :: acc) (i - 1) in
            let l = loop [] (nfields - 1) in
            name_list <- Some l;
            l
        end

  (* Current value of sequence [seq], via currval().  Raises [Not_found]
     if the query yields NULL or no value. *)
  method serial seq =
    if dbh#debug then (
      eprintf "Dbi_postgres: dbh %d: serial \"%s\"\n" dbh#id seq;
      flush stderr
    );

    let sth = dbh#prepare_cached "select currval (?)" in
    sth#execute [`String seq];
    match sth#fetch1 () with
    | [`Int serial] -> Int64.of_int serial
    | [`Int64 serial] -> serial
    | [`Null] | [] -> raise Not_found
    | xs -> invalid_arg ("unknown type returned from select currval: " ^
                           Dbi.sdebug xs)

  (* Drop the current result.  XXX PQclear is not exposed through the
     Postgres library, so the result cannot be freed explicitly here; we
     only forget our reference to it. *)
  method finish () =
    if dbh#debug then (
      eprintf "Dbi_postgres: dbh %d: finish %s\n" dbh#id original_query;
      flush stderr
    );
    tuples <- None

end

(* A Dbi connection to a PostgreSQL database.  [host], [port], [user],
   [password] and [database] (as [~dbname]) are passed to
   [Postgresql.connection]; construction raises [Dbi.SQL_error] if the
   resulting connection status is [Bad]. *)
and connection ?host ?port ?user ?password database =

  (* XXX It is unclear whether arbitrary conninfo settings can be passed in
   * the database field.  They should be; otherwise an assoc list should be
   * used to pass arbitrary parameters to the underlying database.
   *)
  let pg =
    new Pg.connection ?host ?port ?user ?password ~dbname:database () in

  (* Transaction flag handed to every statement created by [prepare], so
   * that all of them know whether BEGIN WORK has already been issued.
   *)
  let in_txn = ref false in

object (self)
  inherit Dbi.connection ?host ?port ?user ?password database as super

  method host = Some pg#host
  method port = Some pg#port
  method user = Some pg#user
  method password = Some pg#pass
  method database = pg#db

  method database_type = "postgres"

  (* Create a statement for [query]; fails if the handle is closed. *)
  method prepare query =
    if self#debug then begin
      eprintf "Dbi_postgres: dbh %d: prepare %s\n" self#id query;
      flush stderr
    end;
    if self#closed then
      failwith "Dbi_postgres: prepare called on closed database handle.";
    let parent = (self : #Dbi.connection :> Dbi.connection) in
    new statement parent pg in_txn query

  (* "commit work" goes through the statement's execute, which issues
     BEGIN WORK first when no transaction is open; afterwards the shared
     flag is cleared so the next statement starts a new transaction. *)
  method commit () =
    super#commit ();
    (self#prepare_cached "commit work")#execute [];
    in_txn := false

  method rollback () =
    (self#prepare_cached "rollback work")#execute [];
    in_txn := false;
    super#rollback ()

  method close () =
    pg#finish;
    super#close ()

  initializer
    match pg#status with
    | Pg.Bad -> raise (Dbi.SQL_error pg#error_message)
    | _ -> ()
end

(* Function-style wrappers around [connection] and its methods. *)
let connect ?host ?port ?user ?password database =
  let dbh = new connection ?host ?port ?user ?password database in
  dbh

let close = fun (dbh : connection) -> dbh#close ()
let closed = fun (dbh : connection) -> dbh#closed
let commit = fun (dbh : connection) -> dbh#commit ()
let ping = fun (dbh : connection) -> dbh#ping ()
let rollback = fun (dbh : connection) -> dbh#rollback ()
