/*
  Logistic Regression using Truncated Iteratively Re-weighted Least Squares
  (includes several programs)
  Copyright (C) 2005  Paul Komarek

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA

  Author: Paul Komarek, komarek@cmu.edu
  Alternate contact: Andrew Moore, awm@cs.cmu.edu

*/


/*
   File:        kfoldsumm.c
   Author:      Paul Komarek
   Created:     Mon May 23 18:21:32 EDT 2005
   Description: Data structure to keep track of per-fold and per-row
                information from a k-fold run.

   Copyright 2005, The Auton Lab
*/



#include "amma.h"
#include "amiv.h"
#include "amdyv.h"

#include "score.h"
#include "lrutils.h"
#include "predict.h"
#include "kfoldsumm.h"


/* Basic data structure operations. */
kfoldsumm *mk_kfoldsumm( int folds, int numrows)
{
  kfoldsumm *kfs;

  kfs = AM_MALLOC(kfoldsumm);

  kfs->fold_to_time = mk_constant_ivec( folds, -1);
  kfs->row_to_foldnum = mk_constant_ivec( numrows, -1);
  kfs->row_to_predict = mk_constant_dyv( numrows, -1.0);

  /* Stats, filled in after all runs. */
  kfs->fold_to_auc = NULL;
  kfs->roc_num_wrong = NULL;
  kfs->roc_num_right = NULL;
  kfs->auc_mean = -1.0;
  kfs->auc_sdev = -1.0;
  kfs->time_mean = -1.0;
  kfs->time_sdev = -1.0;

  return kfs;
}

kfoldsumm *mk_copy_kfoldsumm( const kfoldsumm *kfs)
{
  kfoldsumm *newkfs;

  newkfs = AM_MALLOC(kfoldsumm);
  newkfs->fold_to_time = mk_copy_ivec( kfs->fold_to_time);
  newkfs->row_to_foldnum = mk_copy_ivec( kfs->row_to_foldnum);
  newkfs->row_to_predict = mk_copy_dyv( kfs->row_to_predict);

  return newkfs;
}

void fprintf_kfoldsumm( FILE *f , const char *pre, const kfoldsumm *kfs,
                        const char *post)
{
  fprintf( f, "Hello, world\n");
}

/*
  Release kfs together with every vector it owns.  Passing NULL is a
  no-op.  Vectors that were never allocated (still NULL) are skipped.
*/
void free_kfoldsumm( kfoldsumm *kfs)
{
  if (kfs == NULL) return;

  /* Core per-fold / per-row data. */
  if (kfs->fold_to_time != NULL) free_ivec( kfs->fold_to_time);
  if (kfs->row_to_foldnum != NULL) free_ivec( kfs->row_to_foldnum);
  if (kfs->row_to_predict != NULL) free_dyv( kfs->row_to_predict);

  /* Stats vectors; NULL until kfoldsumm_update_stats() has run. */
  if (kfs->fold_to_auc != NULL) free_dyv( kfs->fold_to_auc);
  if (kfs->roc_num_wrong != NULL) free_ivec( kfs->roc_num_wrong);
  if (kfs->roc_num_right != NULL) free_ivec( kfs->roc_num_right);

  AM_FREE( kfs, kfoldsumm);
}



/* Simple setting operations. */
/* Record the algorithm time, in seconds, taken by fold `fold`. */
void kfoldsumm_set_fold_time( kfoldsumm *kfs, int fold, int seconds)
{
  ivec *times = kfs->fold_to_time;

  ivec_set( times, fold, seconds);
}

/*
  Assign every row listed in subrows to fold `fold`.  Effectively this
  performs row_to_foldnum[subrows[j]] = fold for each j.
*/
void kfoldsumm_set_subfoldnums( kfoldsumm *kfs, const ivec *subrows,
                                int fold)
{
  int j, count;

  count = ivec_size( subrows);
  j = 0;
  while (j < count) {
    ivec_set( kfs->row_to_foldnum, ivec_ref( subrows, j), fold);
    j++;
  }
}

/*
  Scatter per-subset predictions into the full per-row table.  Effectively
  this performs row_to_predict[subrows[j]] = subpredicts[j] for each j, in
  increasing j.  subpredicts must hold at least ivec_size(subrows) entries.
*/
void kfoldsumm_set_subpredicts( kfoldsumm *kfs, const ivec *subrows,
                                const dyv *subpredicts)
{
  int j, count;

  count = ivec_size( subrows);
  j = 0;
  while (j < count) {
    dyv_set( kfs->row_to_predict, ivec_ref( subrows, j),
             dyv_ref( subpredicts, j));
    j++;
  }
}



/* Statistics. */
/*
  Compute summary statistics from the per-row predictions and per-fold
  times stored in kfs.  outputs[r] is the true output for row r (presumably
  a binary class label as expected by mk_roc_curve; confirm).

  Fills in:
    total_auc, roc_num_wrong, roc_num_right -- ROC curve / AUC over all rows
    fold_to_auc                             -- AUC of each fold's rows
    auc_mean, auc_sdev                      -- dyv_mean / dyv_sdev of fold_to_auc
    time_mean, time_sdev                    -- dyv_mean / dyv_sdev of fold_to_time

  May be called more than once.  Stats vectors from an earlier call are
  released first instead of being leaked.
*/
void kfoldsumm_update_stats( kfoldsumm *kfs, ivec *outputs)
{
  int folds, fold;
  double auc;
  ivec *subrows, *subouts;
  dyv *dv, *subpred;

  /* Discard stats from a previous call; the pointers are overwritten below.
     The ROC pointers are reset to NULL before mk_roc_curve fills them. */
  if (kfs->fold_to_auc != NULL) {
    free_dyv( kfs->fold_to_auc);
    kfs->fold_to_auc = NULL;
  }
  if (kfs->roc_num_wrong != NULL) {
    free_ivec( kfs->roc_num_wrong);
    kfs->roc_num_wrong = NULL;
  }
  if (kfs->roc_num_right != NULL) {
    free_ivec( kfs->roc_num_right);
    kfs->roc_num_right = NULL;
  }

  /* Overall auc. */
  kfs->total_auc = mk_roc_curve( outputs, kfs->row_to_predict,
                                 &(kfs->roc_num_wrong), &(kfs->roc_num_right));

  /* Per-fold auc: restrict predictions and outputs to the rows whose
     fold number is `fold`, then score just that subset. */
  folds = ivec_size( kfs->fold_to_time);
  kfs->fold_to_auc = mk_dyv( folds);

  for (fold=0; fold<folds; ++fold) {
    subrows = mk_find_all_in_ivec( kfs->row_to_foldnum, fold);
    subpred = mk_dyv_subset( kfs->row_to_predict, subrows);
    subouts = mk_ivec_subset( outputs, subrows);
    free_ivec( subrows);

    auc = mk_roc_curve( subouts, subpred, NULL, NULL);
    free_ivec( subouts);
    free_dyv( subpred);

    dyv_set( kfs->fold_to_auc, fold, auc);
  }

  /* auc summary. */
  kfs->auc_mean = dyv_mean( kfs->fold_to_auc);
  kfs->auc_sdev = dyv_sdev( kfs->fold_to_auc);

  /* Time summary. */
  dv = mk_dyv_from_ivec( kfs->fold_to_time);
  kfs->time_mean = dyv_mean( dv);
  kfs->time_sdev = dyv_sdev( dv);
  free_dyv( dv);
}

/*
  Print the summary statistics from kfoldsumm_update_stats() to f, one per
  line, each prefixed with `pre`.

  Per-fold AUC and per-fold time lines are printed only when the global
  Verbosity is at least 1.  Per-fold AUC lines are also skipped when the
  stats have not been computed yet (fold_to_auc is NULL), instead of
  dereferencing a NULL vector.
*/
void fprintf_kfoldsumm_stats( FILE *f, const char *pre, const kfoldsumm *kfs)
{
  int fold, folds;

  folds = ivec_size( kfs->fold_to_time);

  /* Score. */
  fprintf( f, "%sTOTAL AUC = %g\n", pre, kfs->total_auc);
  fprintf( f, "%sFOLD AUC MEAN = %g\n", pre, kfs->auc_mean);
  fprintf( f, "%sFOLD AUC SDEV = %g\n", pre, kfs->auc_sdev);
  if (Verbosity >= 1 && kfs->fold_to_auc != NULL) {
    for (fold=0; fold<folds; ++fold) {
      fprintf( f, "%sFOLD %d AUC = %g\n",
               pre, fold, dyv_ref(kfs->fold_to_auc, fold));
    }
  }

  /* Time.  Folds whose time was never set still hold -1 from
     mk_kfoldsumm() and are included in the sum as such. */
  fprintf( f, "%sTOTAL ALG TIME = %d seconds\n",
           pre, ivec_sum( kfs->fold_to_time));
  fprintf( f, "%sFOLD ALG TIME MEAN = %g seconds\n", pre, kfs->time_mean);
  fprintf( f, "%sFOLD ALG TIME SDEV = %g seconds\n", pre, kfs->time_sdev);
  if (Verbosity >= 1) {
    for (fold=0; fold<folds; ++fold) {
      fprintf( f, "%sFOLD %d ALG TIME = %d\n",
               pre, fold, ivec_ref(kfs->fold_to_time, fold));
    }
  }
}



/* Save kfoldsumm information to a file. */
/*
  Write the fold number of every row to `savename`, one integer per line,
  in row order.  Opens the file with sure_pfopen and closes it with
  sure_pfclose.
*/
void kfoldsumm_save_foldnums( const kfoldsumm *kfs, const char *savename)
{
  PFILE *out;
  int row, count;

  out = sure_pfopen( savename, "w");

  count = ivec_size( kfs->row_to_foldnum);
  row = 0;
  while (row < count) {
    pfprintf( out, "%d\n", ivec_ref( kfs->row_to_foldnum, row));
    row++;
  }

  sure_pfclose( out, savename);
}

void kfoldsumm_save_predictions( const kfoldsumm *kfs, const char *savename)
{
  predict_write_pout( savename, kfs->row_to_predict);
  return;
}

/*
  Write the overall ROC curve to `savename`.

  The file starts with '#' comment lines holding a short explanation,
  the dataset row count, the fold count and the summary statistics.  If
  datafname is not NULL, it is also recorded as the data file name.  After
  a blank line, each ROC point is written as "x, y", where x is the number
  wrong so far and y is the number right so far.

  The number of points written is the length of the ROC arrays that
  mk_roc_curve() produced, not the dataset row count, so a curve whose
  length differs from the row count is neither truncated nor over-read.
  If the stats have not been computed (ROC arrays NULL), only the header
  is written.
*/
void kfoldsumm_save_roc( const kfoldsumm *kfs, const char *savename,
                         const char *datafname)
{
  PFILE *f;
  int numrows, numfolds, numpoints, i, x, y;

  numrows = ivec_size( kfs->row_to_foldnum);
  numfolds = ivec_size( kfs->fold_to_time);

  /* Use the shorter of the two ROC arrays so neither is indexed past
     its end. */
  numpoints = 0;
  if (kfs->roc_num_wrong != NULL && kfs->roc_num_right != NULL) {
    numpoints = ivec_size( kfs->roc_num_wrong);
    if (ivec_size( kfs->roc_num_right) < numpoints) {
      numpoints = ivec_size( kfs->roc_num_right);
    }
  }

  f = sure_pfopen( savename, "w");
  pfprintf( f, "# ROC curve:\n");
  pfprintf( f, "# See the documentation that accompanies this software for\n");
  pfprintf( f, "# explanation.  The tuples below are x,y pairs, where\n");
  pfprintf( f, "# x=number wrong so far, and y=number right so far.\n");
  pfprintf( f, "#\n");
  pfflush( f);

  if (datafname != NULL) pfprintf( f, "# Data file: '%s'\n", datafname);
  pfprintf( f, "# Number of dataset rows: %d\n", numrows);
  pfprintf( f, "# Number of folds: %d\n", numfolds);
  pfprintf( f, "#\n");
  pfprintf( f, "# TOTAL AUC = %g\n", kfs->total_auc);
  pfprintf( f, "# FOLD AUC MEAN = %g\n", kfs->auc_mean);
  pfprintf( f, "# FOLD AUC SDEV = %g\n", kfs->auc_sdev);
  pfprintf( f, "# FOLD ALG TIME MEAN = %g seconds\n", kfs->time_mean);
  pfprintf( f, "# FOLD ALG TIME SDEV = %g seconds\n", kfs->time_sdev);

  pfprintf( f, "\n");
  pfflush( f);

  for (i=0; i<numpoints; ++i) {
    x = ivec_ref( kfs->roc_num_wrong, i);
    y = ivec_ref( kfs->roc_num_right, i);
    pfprintf( f, "%d, %d\n", x, y);
  }

  sure_pfclose( f, savename);
}
