/*
  Logistic Regression using Truncated Iteratively Re-weighted Least Squares
  (includes several programs)
  Copyright (C) 2005  Paul Komarek

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA

  Author: Paul Komarek, komarek@cmu.edu
  Alternate contact: Andrew Moore, awm@cs.cmu.edu

*/


/*
   File:        predict.c
   Author:      Paul Komarek
   Created:     Thu Jun 12 03:30:15 EDT 2003
   Description: Trains and stores, or restores and predicts.

   Copyright 2003, The Auton Lab, CMU
*/

#include <stdio.h>
#include <string.h>

#include "amiv.h"
#include "amdyv.h"

#include "spardat.h"

#include "lrutils.h"
#include "score.h"
#include "lr.h"

#include "predict.h"


/**************************************************************************/
/* EXPORT                                                                 */
/**************************************************************************/

/* rr -- 18/02/2006 -- exported entry point of the prediction procedure
   for dense data. */
dyv* __stdcall dense_mk_predicts(dym *testfactors, dyv *testout, lr_predict *lrp)
{
  /* Forwards to mk_predicts() with a NULL spardat, which makes
     mk_predicts() take its dense (dym) branch.  The dyv it returns is
     handed back unchanged. */
  dyv *predictions;

  predictions = mk_predicts( NULL, testfactors, testout, lrp);
  return predictions;
}

/**************************************************************************/
/* TEST                                                                   */
/**************************************************************************/

lr_predict *mk_predict_load_lr_predict( char *loadname)
{
  /* Reads an lr_predict from the file named loadname.  The file is
     opened for binary reading, passed to mk_in_lr_predict(), and closed
     before the loaded object is returned. */
  PFILE *infile = sure_pfopen( loadname, "rb");
  lr_predict *loaded = mk_in_lr_predict( infile);

  pfclose( infile);
  return loaded;
}

dyv *mk_predicts( spardat *testsp, dym *testfactors, dyv *testout,
                  lr_predict *lrp)
{
  /* Builds a new dyv holding one prediction per dataset row, in row
     order.  When testsp is non-NULL the rows come from
     testsp->row_to_posatts; otherwise they come from the rows of
     testfactors.  testout is not used by this function. */
  int row, nrows;
  dyv *result;

  if (testsp == NULL) {
    /* Dense data: copy each dym row into a temporary dyv, predict from
       it, and release the copy before storing the result. */
    nrows = dym_rows( testfactors);
    result = mk_dyv( nrows);
    for (row = 0; row < nrows; ++row) {
      dyv *rowvals = mk_dyv_from_dym_row( testfactors, row);
      double p = lr_predict_predict( NULL, rowvals, lrp);
      free_dyv( rowvals);
      dyv_set( result, row, p);
    }
    return result;
  }

  /* Sparse data: predict straight from each row's ivec. */
  nrows = ivec_array_size( testsp->row_to_posatts);
  result = mk_dyv( nrows);
  for (row = 0; row < nrows; ++row) {
    ivec *posatts = ivec_array_ref( testsp->row_to_posatts, row);
    double p = lr_predict_predict( posatts, NULL, lrp);
    dyv_set( result, row, p);
  }

  return result;
}

dyv *mk_predicts_from_filename( char *fname, lr_predict *lrp,
                                int argc, char **argv,
                                ivec **outputs, double *elapsed)
{
  /* Loads the dataset named by fname, runs mk_predicts() on it, and
     returns the resulting dyv of predictions in dataset order.

     A name ending in ".csv" or ".csv.gz" is read with
     mk_read_dym_for_csv(); any other name is loaded with
     mk_spardat_from_pfilename().

     If outputs != NULL, *outputs receives a newly made ivec of the
     dataset's output values.
     If elapsed != NULL, *elapsed receives difftime() of two time()
     readings taken immediately before and after the mk_predicts()
     call; loading and freeing the dataset are not included. */
  int ends_csv, ends_csv_gz;
  time_t t_begin, t_end;
  dyv *result;

  ends_csv    = string_has_suffix( fname, ".csv");
  ends_csv_gz = string_has_suffix( fname, ".csv.gz");

  if (ends_csv || ends_csv_gz) {
    /* Dense csv input. */
    dym *factors;
    dyv *outvals;

    mk_read_dym_for_csv( fname, &factors, &outvals);
    if (outputs != NULL) {
      *outputs = mk_ivec_from_dyv( outvals);
    }

    t_begin = time( NULL);
    result = mk_predicts( NULL, factors, outvals, lrp);
    t_end = time( NULL);

    free_dyv( outvals);
    free_dym( factors);
  }
  else {
    /* Sparse input. */
    spardat *sp = mk_spardat_from_pfilename( fname, argc, argv);

    if (outputs != NULL) {
      *outputs = mk_copy_ivec( sp->row_to_outval);
    }

    t_begin = time( NULL);
    result = mk_predicts( sp, NULL, NULL, lrp);
    t_end = time( NULL);

    free_spardat( sp);
  }

  if (elapsed != NULL) {
    *elapsed = difftime( t_end, t_begin);
  }

  return result;
}

void predict_write_pout( const char *poutname, const dyv *predicts)
{
  /* Writes every entry of predicts, in index order, to the file named
     poutname: one "%f"-formatted value per line. */
  PFILE *out = sure_pfopen( poutname, "w");
  int count = dyv_size( predicts);
  int row;

  for (row = 0; row < count; ++row) {
    double value = dyv_ref( predicts, row);
    pfprintf( out, "%f\n", value);
  }

  sure_pfclose( out, poutname);
}

void predict_write_roc( const char *savename, const char *datafname,
                        double auc, const ivec *rocx, const ivec *rocy)
{
  /* Writes the ROC coordinates to the file named savename.  The file
     starts with '#'-prefixed comment lines: a short explanation, the
     dataset file name (only when datafname is not NULL), the size of
     rocx, and the AUC.  After a blank line, one "x, y" pair is written
     for each index of rocx, with y taken from the same index of rocy. */
  int npoints = ivec_size( rocx);
  PFILE *out = sure_pfopen( savename, "w");
  int idx;

  /* Fixed explanatory header. */
  pfprintf( out, "# ROC curve:\n");
  pfprintf( out, "# See the documentation that accompanies this software for\n");
  pfprintf( out, "# explanation.  The tuples below are x,y pairs, where\n");
  pfprintf( out, "# x=number wrong so far, and y=number right so far.\n");
  pfprintf( out, "#\n");
  pfflush( out);

  /* Summary information about this particular run. */
  if (datafname != NULL) {
    pfprintf( out, "# Data file: '%s'\n", datafname);
  }
  pfprintf( out, "# Number of dataset rows: %d\n", npoints);
  pfprintf( out, "# AUC = %g\n", auc);
  pfprintf( out, "\n");
  pfflush( out);

  /* The curve itself. */
  idx = 0;
  while (idx < npoints) {
    int wrong = ivec_ref( rocx, idx);
    int right = ivec_ref( rocy, idx);
    pfprintf( out, "%d, %d\n", wrong, right);
    ++idx;
  }

  sure_pfclose( out, savename);
}


void run_predict( char *inname, char *loadname, char *poutname,
                  char *routname, int argc, char **argv)
{
  /* Loads an lr_predict from loadname, computes predictions for the
     dataset inname, and prints the AUC to stdout.  If poutname is not
     NULL the predictions are written there; if routname is not NULL the
     ROC curve is written there.  When Verbosity >= 1 the elapsed time
     reported by mk_predicts_from_filename() is printed as well. */
  double auc, alg_seconds;
  ivec *truth, *roc_x, *roc_y;
  dyv *predictions;
  lr_predict *model;

  /* The lr_predict is only needed while predicting, so release it as
     soon as the predictions exist. */
  model = mk_predict_load_lr_predict( loadname);
  predictions = mk_predicts_from_filename( inname, model, argc, argv,
                                           &truth, &alg_seconds);
  free_lr_predict( model);

  /* ROC curve and AUC. */
  auc = mk_roc_curve( truth, predictions, &roc_x, &roc_y);
  printf( "TOTAL AUC = %g\n", auc);

  /* Optional output files. */
  if (poutname != NULL) {
    predict_write_pout( poutname, predictions);
  }
  if (routname != NULL) {
    predict_write_roc( routname, inname, auc, roc_x, roc_y);
  }

  /* Release everything made above. */
  free_ivec( roc_y);
  free_ivec( roc_x);
  free_ivec( truth);
  free_dyv( predictions);

  if (Verbosity >= 1) {
    printf( "TOTAL ALG TIME = %f seconds\n", alg_seconds);
  }
}


/**************************************************************************/
/* MAIN                                                                   */
/**************************************************************************/

static void usage( char *progname)
{
  /* Prints the command-line synopsis, with progname as the command
     name, followed by the option summary, to stdout. */
  static const char *const trailer[] = {
    "\n",
    "Options:\n",
    "  pout <filename>            Save aggregate predictions.\n",
    "  rout <filename>            Save ROC curve.\n",
    "\n",
    "Use train to learn from a dataset, and predict to compute\n",
    "predictions on a dataset.  This program will print out the AUC\n",
    "when finished.\n",
    "\n"
  };
  const int ntrailer = (int) (sizeof trailer / sizeof trailer[0]);
  int k;

  fputs( "\n", stdout);
  fputs( "Usage:\n", stdout);
  /* The only line that needs formatting. */
  printf( "%s in <predict_datafile> load <filename> [options]\n", progname);

  for (k = 0; k < ntrailer; ++k) {
    fputs( trailer[k], stdout);
  }
}

void predict_main( int argc, char **argv)
{
  /* Entry point for the predict program.  Reads the "in", "load",
     "pout" and "rout" arguments with string_from_args(), using "" as
     the default for each.  "in" and "load" are required: if either is
     left as "", an error is printed to stderr, usage() is shown, and
     the program exits with status -1.  "pout" and "rout" are optional
     and are passed to run_predict() as NULL when left as "". */
  char *data_in  = string_from_args( "in", argc, argv, "");
  char *model_in = string_from_args( "load", argc, argv, "");
  char *pred_out = string_from_args( "pout", argc, argv, "");
  char *roc_out  = string_from_args( "rout", argc, argv, "");

  if (strcmp( data_in, "") == 0) {
    fprintf( stderr, "\npredict_main: no datafile was specified with "
             "'in'.\n\n");
    usage( argv[0]);
    exit(-1);
  }

  if (strcmp( model_in, "") == 0) {
    fprintf( stderr, "\npredict_main: no load filename was specified with "
             "'load'.\n\n");
    usage( argv[0]);
    exit(-1);
  }

  /* An empty name means the option was not given. */
  if (strcmp( pred_out, "") == 0) {
    pred_out = NULL;
  }
  if (strcmp( roc_out, "") == 0) {
    roc_out = NULL;
  }

  run_predict( data_in, model_in, pred_out, roc_out, argc, argv);
}
