/*
  Logistic Regression using Truncated Iteratively Re-weighted Least Squares
  (includes several programs)
  Copyright (C) 2005  Paul Komarek

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA

  Author: Paul Komarek, komarek@cmu.edu
  Alternate contact: Andrew Moore, awm@cs.cmu.edu

*/


/*
   File:        train.c
   Author:      Paul Komarek
   Created:     Thu Jun 12 03:30:15 EDT 2003
   Description: Trains and stores, or restores and predicts.

   Copyright 2003, The Auton Lab, CMU
*/

#include <stdio.h>
#include <string.h>

#include "amiv.h"
#include "amdyv.h"
#include "amdym.h"
#include "spardat.h"

#include "file.h"
#include "lr.h"

#include "train.h"

/**************************************************************************/
/* EXPORT                                                                 */
/**************************************************************************/

/* rr -- 18/02/2006 -- exports the training procedure for dense data.

   Exported (__stdcall) entry point that trains from a dense factor matrix
   and an output vector only.  It forwards to mk_train_lr_predict() and
   returns whatever that call returns. */
lr_predict* __stdcall dense_mk_train_lr_predict(dym *factors, dyv *outputs, lr_options *opts)
{
  /* The sparse dataset argument is always NULL here: only the standard
     (dense) data is used by this entry point. */
  lr_predict *model = mk_train_lr_predict(NULL, factors, outputs, opts);
  return model;
}

/**************************************************************************/
/* TRAIN                                                                  */
/**************************************************************************/

/* Writes the predictor lrp to the file named savename.

   The file is opened in "w" mode with sure_pfopen(), filled by
   out_lr_predict(), and closed with sure_pfclose().
   NOTE(review): the sure_* helpers presumably abort on I/O failure --
   confirm in file.h; no error status is returned from here. */
void train_save( char *savename, lr_predict *lrp)
{
  PFILE *out = sure_pfopen( savename, "w");

  out_lr_predict( out, lrp);
  sure_pfclose( out, savename);
}

/* Trains a logistic regression model and returns a predictor built from it.

   sp, factors, outputs and opts are handed straight to mk_lr_train() (its
   fourth argument is passed as NULL).  The predictor is built from the
   training structure's b0 and b references, after which the training
   structure is released.  run_train() releases the returned predictor with
   free_lr_predict().
   NOTE(review): mk_lr_predict() presumably copies b0/b, since the training
   structure is freed before the predictor is returned -- confirm in lr.h. */
lr_predict* mk_train_lr_predict( spardat *sp, dym *factors, dyv *outputs, lr_options *opts)
{
  lr_train *fit = mk_lr_train( sp, factors, outputs, NULL, opts);
  lr_predict *model = mk_lr_predict( lrt_b0_ref(fit), lrt_b_ref(fit));

  free_lr_train( fit);
  return model;
}

/* Loads the training data named by inname, trains a model, and saves the
   resulting predictor to savename.

   inname    training data file; names ending in ".csv" or ".csv.gz" are
             read as dense csv data, anything else is loaded as a spardat.
   savename  file the trained predictor is written to.
   argc/argv command-line arguments, passed on to the lr option parser and
             to the spardat loader.

   When Verbosity >= 1 the time spent in training is printed to stdout. */
void run_train( char *inname, char *savename, int argc, char **argv)
{
  /* Exactly one input representation is loaded; the other stays NULL. */
  spardat *sparse = NULL;
  dym *dense_factors = NULL;
  dyv *dense_outputs = NULL;
  lr_options *options;
  lr_predict *model;
  time_t t_begin, t_end;
  double elapsed;
  int ends_csv, ends_csv_gz, is_csv;

  /* Decide the input format from the filename suffix. */
  ends_csv    = string_has_suffix( inname, ".csv");
  ends_csv_gz = string_has_suffix( inname, ".csv.gz");
  is_csv      = ends_csv | ends_csv_gz;

  /* Build, parse and validate the lr options. */
  options = mk_lr_options();
  parse_lr_options( options, argc, argv);
  check_lr_options( options, argc, argv);

  /* Load the data in whichever format the filename indicated. */
  if (is_csv) {
    mk_read_dym_for_csv( inname, &dense_factors, &dense_outputs);
    if (!dyv_is_binary( dense_outputs)) {
      my_error( "run_train: Error: csv output column is not binary.\n");
    }
  }
  else {
    sparse = mk_spardat_from_pfilename( inname, argc, argv);
  }

  /* Train; only this call is timed (time() has one-second granularity). */
  t_begin = time( NULL);
  model = mk_train_lr_predict( sparse, dense_factors, dense_outputs, options);
  t_end = time( NULL);
  elapsed = difftime( t_end, t_begin);

  /* The inputs and options are no longer needed once training is done. */
  if (sparse != NULL) free_spardat( sparse);
  if (dense_factors != NULL) free_dym( dense_factors);
  if (dense_outputs != NULL) free_dyv( dense_outputs);
  free_lr_options( options);

  /* Persist the predictor so predictions can be made later. */
  train_save( savename, model);
  free_lr_predict( model);

  /* Report the training time if desired. */
  if (Verbosity >= 1) {
    printf( "TOTAL ALG TIME = %f seconds\n", elapsed);
  }
}

/**************************************************************************/
/* MAIN                                                                   */
/**************************************************************************/

/* Prints the command-line usage message for the training program to
   stdout.  progname is substituted as the program name in the usage line. */
static void usage( char *progname)
{
  /* Adjacent string literals are concatenated by the compiler, so the
     emitted text is the same as printing each piece separately. */
  printf( "\n"
          "Usage:\n");
  printf( "%s in <train_datafile> "
          "save <filename>\n", progname);
  printf( "\n"
          "Use train to learn from a dataset, and predict to compute "
          "predictions on a\ndataset.  predict will print out the AUC "
          "when finished."
          "\n");
}

/* Entry point for training: reads the 'in' and 'save' arguments from the
   command line and runs training.

   If either argument is missing (string_from_args() yields its "" default),
   an error is written to stderr, the usage message is printed, and the
   program exits with status -1. */
void train_main( int argc, char **argv)
{
  char *data_path  = string_from_args( "in", argc, argv, "");
  char *model_path = string_from_args( "save", argc, argv, "");

  if (strcmp( data_path, "") == 0) {
    fprintf( stderr, "\ntrain_main: no datafile was specified with 'in'.\n\n");
    usage( argv[0]);
    exit(-1);
  }

  if (strcmp( model_path, "") == 0) {
    fprintf( stderr,
             "\ntrain_main: no save filename was specified with 'save'.\n\n");
    usage( argv[0]);
    exit(-1);
  }

  run_train( data_path, model_path, argc, argv);
}
