/* $Id: test_paillier.C,v 1.6 2006/02/25 02:11:36 mfreed Exp $ */

/*
 *
 * Copyright (C) 2005 Michael J. Freedman (mfreedman at alum.mit.edu)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#define USE_PCTR 0

#include "crypt_prot.h"
#include "crypt.h"
#include "paillier.h"
#include "bench.h"

#include <errno.h>
#include <limits.h>
#include <stdlib.h>

/* Accumulated timings (values from get_time (), reported in TIME_LABEL
   units) for encryption, decryption, and homomorphic addition.  They are
   reset at the start of each do_test run. */
u_int64_t etime, dtime, htime;

/* Number of keys generated per do_test run, and test rounds per key. */
static const size_t repeat = 5, cnt = 10;

/*
 * Encrypt and then decrypt one random message under sk.  Panics if either
 * step fails or the round trip does not reproduce the original message.
 *
 * On return, ptext holds the random plaintext and ctext its ciphertext, so
 * the caller can reuse both.  The time spent in sk.encrypt is added to the
 * global etime, and the time in sk.decrypt is added to dtime.
 */
void
do_paillier (paillier_priv &sk, crypt_ctext &ctext, str &ptext)
{
  // Pick a random length below (modulus byte length - 1), with a floor of
  // 20 bytes.  Guard the divisor: for moduli under 16 bits it is either 0
  // (division by zero) or wraps to a huge unsigned value (huge length).
  size_t nbytes = mpz_sizeinbase2 (&sk.n) / 8;
  u_int len = 0;
  if (nbytes > 1)
    len = (u_int) (rnd.getword () % (nbytes - 1));
  // NOTE(review): for very small moduli the 20-byte floor may exceed what
  // sk.encrypt accepts; that case surfaces as the "Encryption failed" panic.
  len = max (len, (u_int) 20);

  wmstr wmsg (len);
  rnd.getbytes (wmsg, len);
  ptext = wmsg;

  u_int64_t tmp1 = get_time ();

  if (!sk.encrypt (&ctext, ptext)) {
    strbuf sb;
    sb << "Encryption failed\n"
       << "  p  = " << sk.p << "\n"
       << "  q  = " << sk.q << "\n"
       << "msg1 = " << hexdump (ptext.cstr (), ptext.len ()) << "\n";
    panic << sb;
  }

  u_int64_t tmp2 = get_time ();
  str ptext2 = sk.decrypt (ctext, len);
  u_int64_t tmp3 = get_time ();

  // A null result means decryption itself failed; only dump it if present.
  if (!ptext2 || ptext != ptext2) {
    strbuf sb;
    sb << "Decryption failed\n"
       << "  p  = " << sk.p << "\n"
       << "  q  = " << sk.q << "\n"
       << "msg1 = " << hexdump (ptext.cstr (), ptext.len ()) << "\n";
    if (ptext2)
      sb << "msg2 = " << hexdump (ptext2.cstr (), ptext2.len ()) << "\n";
    panic << sb;
  }

  etime += (tmp2 - tmp1);
  dtime += (tmp3 - tmp2);
}


/*
 * Run cnt rounds of checks against sk.  Each round encrypts and decrypts two
 * random messages (via do_paillier), then combines the two ciphertexts with
 * sk.add.  It checks that decrypting the combination equals the sum of the
 * pre-encoded plaintexts after post-decoding.  Only the sk.add call is timed
 * (accumulated into the global htime).  Panics on any mismatch.
 */
void
test_paillier (paillier_priv &sk)
{
  for (size_t i = 0; i < cnt; i++) {
    crypt_ctext ctext1 (sk.ctext_type ());
    crypt_ctext ctext2 (sk.ctext_type ());
    str ptext1, ptext2;

    do_paillier (sk, ctext1, ptext1);
    do_paillier (sk, ctext2, ptext2);

    // Reference result: the sum computed directly in the plaintext domain.
    bigint ptextc = sk.pre_encrypt (ptext1) + sk.pre_encrypt (ptext2);

    // Homomorphic addition of the two ciphertexts (the timed operation).
    u_int64_t tmp1 = get_time ();
    crypt_ctext ctextc (sk.ctext_type ());
    sk.add (&ctextc, ctext1, ctext2);
    u_int64_t tmp2 = get_time ();

    // One byte longer than the longer input, presumably to leave room for
    // a carry out of the addition.
    u_int len = max (ptext1.len (), ptext2.len ()) + 1;
    str cres  = sk.decrypt (ctextc, len);
    str pres  = sk.post_decrypt (ptextc, len);

    // Treat a null result as a failure, and never hexdump a null str.
    if (!cres || !pres || cres != pres) {
      strbuf sb;
      sb << "Homomorphic addition failed\n"
	 << "\n        msg1 = " << hexdump (ptext1.cstr (), ptext1.len ())
	 << "\n        msg2 = " << hexdump (ptext2.cstr (), ptext2.len ());
      if (cres)
	sb << "\n cipher comb = " << hexdump (cres.cstr (), cres.len ());
      if (pres)
	sb << "\n  plain comb = " << hexdump (pres.cstr (), pres.len ());
      sb << "\n";
      panic << sb;
    }

    htime += (tmp2 - tmp1);
  }
}

void
do_test (int vsz, int asz, bool opt_v, bool fast)
{
  etime = dtime = htime = 0;

  if (!opt_v) {
    vsz = 424 + rnd.getword () % 256;
    asz = 160 + rnd.getword () % 256;
  }

  for (size_t i = 0; i < repeat; i++) {
    if (fast) {
      paillier_priv sk = paillier_keygen (vsz, asz);
      test_paillier (sk);
    }
    else {
      paillier_priv sk = paillier_skeygen (vsz);
      test_paillier (sk);
    }
  }
  
  if (opt_v) {
    size_t tot  = repeat * cnt;
    size_t ttot = tot * 2;
    if (fast)
      warn ("Paillier cryptosystem with %d bit key [%d] (fast mode)\n", 
	    vsz, asz);
    else
      warn ("Paillier cryptosystem with %d bit key [%d] (normal mode)\n", 
	    vsz, asz);

    warn ("   Encrypted  %u messages in %" U64F "u " 
	  TIME_LABEL " per message\n", ttot, (etime / ttot));
    warn ("   Decrypted  %u messages in %" U64F "u " 
	  TIME_LABEL " per message\n", ttot, (dtime / ttot));
    warn ("   Homo-added %u messages in %" U64F "u " 
	  TIME_LABEL " per message\n", tot, (htime / tot));
  }
}


int
main (int argc, char **argv)
{
  bool opt_v  = false;
  bool dofast = true;
  int  vsz    = 1024;
  int  asz    = 160;

  for (int i=1; i < argc; i++) {
    if (!strcmp (argv[i], "-v"))
      opt_v = true;
    else if (!strcmp (argv[i], "-b")) {
      assert (argc > i+1);
      vsz = atoi (argv[i+1]);
      assert (vsz > 0);
    }
    else if (!strcmp (argv[i], "-a")) {
      assert (argc > i+1);
      asz = atoi (argv[i+1]);
      assert (asz > 0);
    }
    else if (!strcmp (argv[i], "-n"))
      dofast = false;
  }

  setprogname (argv[0]);
  random_update ();

  if (dofast)
    do_test (vsz, asz, opt_v, true);
  do_test (vsz, asz, opt_v, false);
  
  return 0;
}
