/*
  Logistic Regression using Truncated Iteratively Re-weighted Least Squares
  (includes several programs)
  Copyright (C) 2005  Paul Komarek

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA

  Author: Paul Komarek, komarek@cmu.edu
  Alternate contact: Andrew Moore, awm@cs.cmu.edu

*/


/*
   File:        kfold.c
   Author:      Paul Komarek
   Created:     Thu Jun 12 03:30:15 EDT 2003
   Description: k-fold cross-validation.

   Copyright 2003, The Auton Lab, CMU
*/

#include <stdio.h>
#include <string.h>
#include <time.h>

#include "amiv.h"
#include "amdyv.h"
#include "amdym.h"
#include "spardat.h"

#include "lrutils.h"
#include "lr.h"
#include "train.h"
#include "predict.h"

#include "kfoldsumm.h"
#include "kfold.h"



/**************************************************************************/
/* KFOLD                                                                  */
/**************************************************************************/

/* Train on one fold's training rows and predict its held-out test rows.

   Exactly one data representation is used per call:
     - sp != NULL: row subsets are taken from the spardat; factors and
       outputs are not read.
     - sp == NULL: row subsets are taken from factors and outputs.

   train_rows / test_rows select the rows for training and testing.
   opts is handed unchanged to mk_train_lr_predict().

   Returns the dyv produced by mk_predicts() for the test rows.  The
   caller owns it and releases it with free_dyv().  All intermediate
   subsets and the trained predictor are freed before returning. */
dyv *mk_fold_predictions( const spardat *sp, const dym *factors,
                          const dyv *outputs, ivec *train_rows,
                          ivec *test_rows, lr_options *opts)
{
  dyv *preds;
  dyv *trainout = NULL, *testout = NULL;
  spardat *trainsp = NULL, *testsp = NULL;
  dym *trainfactors = NULL, *testfactors = NULL;
  lr_predict *lrp;

  /* Build the train/test subsets in whichever representation was given;
     the pointers for the other representation stay NULL. */
  if (sp != NULL) {
    trainsp = mk_spardat_from_subset_of_rows( sp, train_rows);
    testsp = mk_spardat_from_subset_of_rows( sp, test_rows);
  }
  else {
    trainfactors = mk_dym_subset( factors, train_rows, NULL);
    testfactors = mk_dym_subset( factors, test_rows, NULL);
    trainout = mk_dyv_subset( outputs, train_rows);
    testout = mk_dyv_subset( outputs, test_rows);
  }

  /* Fit on the training subset, then score the test subset. */
  lrp = mk_train_lr_predict( trainsp, trainfactors, trainout, opts);
  preds = mk_predicts( testsp, testfactors, testout, lrp);
  free_lr_predict( lrp);

  /* Only the subsets that were actually created are released. */
  if (trainout != NULL) free_dyv( trainout);
  if (testout != NULL) free_dyv( testout);
  if (trainsp != NULL) free_spardat( trainsp);
  if (testsp != NULL) free_spardat( testsp);
  if (trainfactors != NULL) free_dym( trainfactors);
  if (testfactors != NULL) free_dym( testfactors);

  return preds;
}

/* Run every fold of the cross-validation and record the results in kfs.

   The row count comes from sp when it is non-NULL, otherwise from
   factors.  For each fold index 0..folds-1 the rows are split by
   make_kfold_rows(), predictions are produced by mk_fold_predictions(),
   and the elapsed wall-clock time (time() resolution), the fold
   membership of the test rows and their predictions are stored in kfs.
   Everything allocated for a fold is freed before the next one starts. */
void kfold_run_folds( kfoldsumm *kfs, const spardat *sp, const dym *factors,
                      const dyv *outputs, int folds, lr_options *opts)
{
  int k, nrows;

  nrows = (sp != NULL) ? spardat_num_rows( sp) : dym_rows( factors);

  for (k = 0; k < folds; ++k) {
    time_t t_begin, t_end;
    ivec *rows_train, *rows_test;
    dyv *fold_preds;

    /* Split the row indices into this fold's train and test sets. */
    make_kfold_rows( NULL, nrows, folds, k, &rows_train, &rows_test);

    /* Time the train-and-predict step. */
    t_begin = time( NULL);
    fold_preds = mk_fold_predictions( sp, factors, outputs,
                                      rows_train, rows_test, opts);
    t_end = time( NULL);

    /* The training rows are not needed once predictions exist. */
    free_ivec( rows_train);

    /* Store timing, fold assignment and predictions for the test rows. */
    kfoldsumm_set_fold_time( kfs, k, t_end-t_begin);
    kfoldsumm_set_subfoldnums( kfs, rows_test, k);
    kfoldsumm_set_subpredicts( kfs, rows_test, fold_preds);

    /* Release what is left of this iteration. */
    free_ivec( rows_test);
    free_dyv( fold_preds);
  }
}

/* Load a data file, run k-fold cross-validation on it and report results.

   inname  Input file.  Names ending in ".csv" or ".csv.gz" are read with
           mk_read_dym_for_csv(); anything else is read with
           mk_spardat_from_pfilename().
   folds   Number of folds.
   pout    File for the aggregate predictions, or NULL to skip.
   fout    File for the fold assignments, or NULL to skip.
   rout    File for the ROC curve, or NULL to skip.
   argc, argv  Command line, forwarded to the lr option parser and to the
           spardat loader.

   Summary statistics are always printed to stdout. */
void run_kfold( char *inname, int folds, char *pout, char *fout, char *rout,
                int argc, char **argv)
{
  int is_csv, nrows;
  ivec *truth;
  dym *dense_in = NULL;
  dyv *dense_out = NULL;
  spardat *sparse = NULL;
  lr_options *lropts;
  kfoldsumm *summary;

  is_csv = string_has_suffix( inname, ".csv");
  is_csv |= string_has_suffix( inname, ".csv.gz");

  /* Parse and validate the logistic regression options. */
  lropts = mk_lr_options();
  parse_lr_options( lropts, argc, argv);
  check_lr_options( lropts, argc, argv);

  /* Load the whole data set and keep an integer copy of the outputs,
     which is needed for the statistics after the data itself is freed. */
  if (is_csv) {
    mk_read_dym_for_csv( inname, &dense_in, &dense_out);
    if (!dyv_is_binary( dense_out)) {
      my_error( "run_kfold: Error: csv output column is not binary.\n");
    }
    nrows = dym_rows( dense_in);
    truth = mk_ivec_from_dyv( dense_out);
  }
  else {
    sparse = mk_spardat_from_pfilename( inname, argc, argv);
    nrows = spardat_num_rows( sparse);
    truth = mk_copy_ivec( sparse->row_to_outval);
  }

  /* Cross-validate. */
  summary = mk_kfoldsumm( folds, nrows);
  kfold_run_folds( summary, sparse, dense_in, dense_out, folds, lropts);

  /* The input data and options are no longer needed. */
  if (sparse != NULL) free_spardat( sparse);
  if (dense_in != NULL) free_dym( dense_in);
  if (dense_out != NULL) free_dyv( dense_out);
  free_lr_options( lropts);

  /* Compute and print the statistics. */
  kfoldsumm_update_stats( summary, truth);
  free_ivec( truth);
  fprintf_kfoldsumm_stats( stdout, "", summary);

  /* Write whichever optional output files were requested. */
  if (fout != NULL) kfoldsumm_save_foldnums( summary, fout);
  if (pout != NULL) kfoldsumm_save_predictions( summary, pout);
  if (rout != NULL) kfoldsumm_save_roc( summary, rout, inname);
  free_kfoldsumm( summary);
}



/**************************************************************************/
/* MAIN                                                                   */
/**************************************************************************/

/* Print the command-line help text to stdout.  progname is shown as the
   command name in the usage line. */
static void usage( char *progname)
{
  /* Heading. */
  fputs( "\n"
         "Usage:\n", stdout);

  /* The only line that needs formatting. */
  printf( "%s in <train_datafile> folds <int> [options]\n", progname);

  /* Option list and description. */
  fputs( "\n"
         "Options:\n"
         "  fout <filename>            Save fold assignments.\n"
         "  pout <filename>            Save aggregate predictions.\n"
         "  rout <filename>            Save ROC curve.\n"
         "\n"
         "This program runs a k-fold cross-validation on the specified\n"
         "data file.  Statistics are reported at the end.  Use the\n"
         "option verbosity <int> to increase the amount of information\n"
         "printed at the end.\n"
         "\n", stdout);
}

/* Command-line entry point for k-fold cross-validation.

   Recognised arguments (read from argv as key/value pairs):
     in <file>     Data file (required).
     folds <int>   Number of folds (required, must be at least 1).
     fout <file>   Optional: save fold assignments.
     pout <file>   Optional: save aggregate predictions.
     rout <file>   Optional: save ROC curve.

   On a missing data file or an invalid fold count, an error is written
   to stderr, the usage text is printed, and the process exits with
   status -1.  Otherwise control passes to run_kfold(), which also
   receives argc/argv so that further options can be parsed there. */
void kfold_main( int argc, char **argv)
{
  char *inname, *poutname, *foutname, *routname;
  int folds;

  inname   = string_from_args( "in", argc, argv, "");
  folds    = int_from_args( "folds", argc, argv, -1);
  foutname = string_from_args( "fout", argc, argv, "");
  poutname = string_from_args( "pout", argc, argv, "");
  routname = string_from_args( "rout", argc, argv, "");

  if (strcmp( inname, "") == 0) {
    fprintf( stderr, "\nkfold_main: no datafile was specified with 'in'.\n\n");
    usage( argv[0]);
    exit(-1);
  }

  /* Zero folds must be rejected along with negative values: with no
     folds, no predictions are ever recorded, yet run_kfold() would
     still compute and print statistics for them. */
  if (folds < 1) {
    fprintf( stderr,
             "\nkfold_main: number of folds not specified, or specified but\n"
             "not positive.\n\n");
    usage( argv[0]);
    exit(-1);
  }

  /* An empty name means the option was not given; run_kfold() expects
     NULL for outputs that should be skipped. */
  if (strcmp( foutname, "") == 0) foutname = NULL;
  if (strcmp( poutname, "") == 0) poutname = NULL;
  if (strcmp( routname, "") == 0) routname = NULL;

  run_kfold( inname, folds, poutname, foutname, routname, argc, argv);
}
