(****************************************************************************)
(* UCompSpvAssesCrossValidation.pas - Copyright (c) 2004 Ricco RAKOTOMALALA *)
(****************************************************************************)

{
@abstract(Multiple cross validation)
@author(Ricco)
@created(12/01/2004)
Two parameters: the number of folds and the number of repetitions; setting the latter to 1
gives the standard cross-validation.
}
unit UCompSpvAssesCrossValidation;

interface

USES
        Forms,Classes,IniFiles,
        UCompDefinition,
        UCompSpvAssesDefinition,
        UOperatorDefinition;

TYPE
        {Generator class of the cross-validation assessment component.
         It fixes the component kind and names the component class it stands for.}
        TMLGenCompAssesCV = class(TMLGenComp)
                            protected
                            {Sets FMLComp to mlcSpvAssessment.}
                            procedure   GenCompInitializations(); override;
                            public
                            {Returns the associated component class, TMLCompAssesCV.}
                            function    GetClassMLComponent: TClassMLComponent; override;
                            end;

        {Component: the cross-validation assessment component itself.
         Its only specialization is the operator class it is bound to.}
        TMLCompAssesCV =  class(TMLCompSpvAsses)
                          protected
                          {Returns the operator class of this component, TOpAssesCV.}
                          function    getClassOperator: TClassOperator; override;
                          end;

        {Operator: performs the repeated cross-validation and reports its error rates.}
        TOpAssesCV =  class(TOpSpvAsses)
                      private
                      {extreme values: smallest and largest error rate over the repetitions}
                      FMin,FMax: double;
                      {table of test error rates, one per repetition;
                       indexed from 1 (it is sized NbRepetitions+1, slot 0 is not used)}
                      FTabTestErr: array of double;
                      protected
                      {Runs the inherited re-initialization, then empties FTabTestErr.}
                      procedure   ReInitialize(); override;
                      {Runs the inherited preparation, then sizes FTabTestErr to NbRepetitions+1.}
                      procedure   PrepareConfMatrix(); override;
                      {Returns the parameter class of this operator, TOpPrmAssesCV.}
                      function    getClassParameter: TClassOperatorParameter; override;
                      {Executes the cross-validation: fills FMin, FMax, FTabTestErr
                       and accumulates the per-repetition confusion matrices.}
                      procedure   AssesExecution(); override;
                      public
                      {Builds the HTML report: MIN/MAX, error rate per trial,
                       followed by the HTML output of the accumulated confusion matrix.}
                      function    getHTMLResultsSummary(): string; override;
                      end;

        {Parameters of the operator: number of folds and number of repetitions.}
        TOpPrmAssesCV = class(TOpPrmSpvAsses)
                        private
                        {number of repetitions}
                        FNbRepetitions: integer;
                        {number of folds for one session}
                        FNbFolds: integer;
                        protected
                        {Creates the parameter dialog (TDlgOpPrmSpvAssesCV) for this object.}
                        function    CreateDlgParameters(): TForm; override;
                        {Inherited defaults, then 5 repetitions and 2 folds.}
                        procedure   SetDefaultParameters(); override;
                        public
                        {Binary persistence: after the inherited data, the two integers
                         are read/written as raw buffers, repetitions first, folds second.}
                        procedure   LoadFromStream(prmStream: TStream); override;
                        procedure   SaveToStream(prmStream: TStream); override;
                        {INI persistence: keys 'nb_repetitions' and 'nb_folds' in prmSection;
                         on load the current field values serve as defaults.}
                        procedure   LoadFromINI(prmSection: string; prmINI: TMemIniFile); override;
                        procedure   SaveToINI(prmSection: string; prmINI: TMemIniFile); override;
                        {Returns an HTML table listing the folds and trials settings.}
                        function    getHTMLParameters(): string; override;
                        {Direct read/write access to the two settings (no validation is applied).}
                        property    NbRepetitions: integer read FNbRepetitions write FNbRepetitions;
                        property    NbFolds: integer read FNbFolds write FNbFolds;
                        end;

implementation

uses
        Sysutils,
        UStringsResources, UDatasetExamples, UCompSpvLDefinition,
        UDatasetImplementation, UDatasetDefinition, UConstConfiguration,
  UDlgOpPrmSpvAssesCV, ULogFile, UCalcRndGenerator;

{ TMLGenCompAssesCV }

{ Tags this generator with the mlcSpvAssessment component kind. }
procedure TMLGenCompAssesCV.GenCompInitializations;
begin
  FMLComp := mlcSpvAssessment;
  // The icon index, component name and bitmap file name are not assigned here;
  // the corresponding statements are kept below, disabled as in the original:
  //   FMLNumIcon:= 31;
  //   FMLCompName:= str_comp_name_spvl_asses_cv;
  //   FMLBitmapFileName:= 'MLSpvAssesCrossValidation.bmp';
end;

{ Returns the component class associated with this generator. }
function TMLGenCompAssesCV.GetClassMLComponent: TClassMLComponent;
begin
  // The cross-validation generator always maps to TMLCompAssesCV.
  Result := TMLCompAssesCV;
end;

{ TMLCompAssesCV }

{ Returns the operator class bound to this component. }
function TMLCompAssesCV.getClassOperator: TClassOperator;
begin
  // The cross-validation component is always driven by TOpAssesCV.
  Result := TOpAssesCV;
end;


{ TOpPrmAssesCV }

{ Creates the parameter-editing dialog for this parameter object.
  Returns a new TDlgOpPrmSpvAssesCV built from Self. }
function TOpPrmAssesCV.CreateDlgParameters: TForm;
begin
  Result := TDlgOpPrmSpvAssesCV.CreateFromOpPrm(Self);
end;

{ Renders the cross-validation settings as an HTML table:
  one header row, then one row for the folds and one for the trials. }
function TOpPrmAssesCV.getHTMLParameters: string;
begin
  Result :=
    // table opening
    HTML_HEADER_TABLE_RESULT
    // header row
    + HTML_TABLE_COLOR_HEADER_GRAY
    + '<TH colspan=2>Cross-validation parameters</TH></TR>'
    // number of folds
    + HTML_TABLE_COLOR_DATA_GRAY
    + Format('<TD>Folds</TD><TD align="right">%d</TD></TR>', [FNbFolds])
    // number of repetitions
    + HTML_TABLE_COLOR_DATA_GRAY
    + Format('<TD>Trials</TD><TD align="right">%d</TD></TR>', [FNbRepetitions])
    // table closing
    + '</table>';
end;

{ Reads the cross-validation settings from section prmSection of prmINI,
  after the inherited settings. A missing key leaves the current value
  of the corresponding field unchanged (it is passed as the default). }
procedure TOpPrmAssesCV.LoadFromINI(prmSection: string;
  prmINI: TMemIniFile);
begin
  inherited;
  // The two keys are independent of each other.
  FNbFolds := prmINI.ReadInteger(prmSection, 'nb_folds', FNbFolds);
  FNbRepetitions := prmINI.ReadInteger(prmSection, 'nb_repetitions', FNbRepetitions);
end;

{ Restores the settings from a binary stream, after the inherited data.
  Layout: number of repetitions first, number of folds second, each stored
  as the raw bytes of an integer field. The order must match SaveToStream. }
procedure TOpPrmAssesCV.LoadFromStream(prmStream: TStream);
begin
  inherited;
  // 1st value: repetitions
  prmStream.ReadBuffer(FNbRepetitions, SizeOf(FNbRepetitions));
  // 2nd value: folds
  prmStream.ReadBuffer(FNbFolds, SizeOf(FNbFolds));
end;

{ Writes the cross-validation settings into section prmSection of prmINI,
  after the inherited settings, under the keys 'nb_repetitions' and 'nb_folds'. }
procedure TOpPrmAssesCV.SaveToINI(prmSection: string; prmINI: TMemIniFile);
begin
  inherited;
  // The two keys are independent of each other.
  prmINI.WriteInteger(prmSection, 'nb_folds', FNbFolds);
  prmINI.WriteInteger(prmSection, 'nb_repetitions', FNbRepetitions);
end;

{ Stores the settings into a binary stream, after the inherited data.
  Layout: number of repetitions first, number of folds second, each written
  as the raw bytes of an integer field. The order must match LoadFromStream. }
procedure TOpPrmAssesCV.SaveToStream(prmStream: TStream);
begin
  inherited;
  // 1st value: repetitions
  prmStream.WriteBuffer(FNbRepetitions, SizeOf(FNbRepetitions));
  // 2nd value: folds
  prmStream.WriteBuffer(FNbFolds, SizeOf(FNbFolds));
end;

{ Applies the inherited defaults, then the cross-validation defaults:
  2 folds, repeated 5 times. }
procedure TOpPrmAssesCV.SetDefaultParameters;
begin
  inherited;
  FNbFolds := 2;
  FNbRepetitions := 5;
end;

{ TOpAssesCV }

{ Runs the repeated cross-validation.

  For each of the NbRepetitions repetitions the examples are shuffled, then
  split into NbFolds consecutive blocks of (nbExamples div NbFolds) examples.
  Each block in turn is the test sample; all remaining examples go to
  RootExamples and the learner is executed on them. The test predictions of
  all folds of one repetition are accumulated in a temporary confusion
  matrix, whose error rate is stored in FTabTestErr[r] (1-based) and which is
  then added to ConfMatrixAsses. FMin/FMax receive the extreme error rates.

  Notes:
  - Examples beyond NbFolds*foldSize (remainder of the integer division) are
    never part of a test sample; they are always used for training.
  - NbFolds = 0 still raises the integer division error, as before.
  - Temporary objects are released even if an exception escapes.
  - RootExamples is left holding the last training sample, as before. }
procedure TOpAssesCV.AssesExecution;
var testSample,sample: TExamples;
    r,f,i: integer;
    nbRep,nbFolds,nbExamples,foldSize: integer;
    bb,bh: integer;
    tmpCM: TConfusionMatrix;
    kClass,kPred: TTypeDiscrete;
    ClassAtt: TAttribute;
    example: integer;
    tmpErr: double;
begin
 //sentinels, replaced by the first repetition's error rate
 FMin:= 1.0e308;
 FMax:= -1.0e308;
 //attribute to predict
 ClassAtt:= CompMetaSpv.OutputData.LstAtts[asTarget].Attribute[0];
 //repetitions and folds
 nbRep:= (prmOP as TOpPrmAssesCV).NbRepetitions;
 nbFolds:= (prmOP as TOpPrmAssesCV).NbFolds;
 nbExamples:= AllExamples.Size;
 foldSize:= nbExamples div nbFolds;
 //make sure the 1-based result table can hold every repetition;
 //normally PrepareConfMatrix has already sized it and this is a no-op
 if (Length(FTabTestErr)<succ(nbRep))
  then setLength(FTabTestErr,succ(nbRep));
 //temporaries start as nil so that the finally part can free them safely
 tmpCM:= nil;
 sample:= nil;
 testSample:= nil;
 try
  //confusion matrix of one repetition
  tmpCM:= TConfusionMatrix.createStructure(ClassAtt);
  //working copy of the examples (shuffled at each repetition) and test sample
  sample:= TExamples.Create(AllExamples.Size);
  sample.Copy(AllExamples);
  testSample:= TExamples.Create(AllExamples.Size);
  //for each repetition
  for r:= 1 to nbRep do
   begin
    //randomly change the order of the examples
    sample.procRandomizeExamples(seedRandom);
    //reset the confusion matrix
    tmpCM.CrossTab.ReInitialization();
    //for each fold
    for f:= 1 to nbFolds do
     begin
      //test block = positions bb+1 .. bh-1 of the shuffled sample
      bb:= pred(f)*foldSize;
      bh:= f*foldSize+1;
      //NOTE(review): BeginAdd presumably clears the previous content - confirm
      self.RootExamples.BeginAdd();
      testSample.BeginAdd();
      for i:= 1 to sample.Size do
       begin
        if (i>bb) and (i<bh)
         then testSample.AddExample(sample.Number[i])
         else self.RootExamples.AddExample(sample.Number[i]);
       end;
      self.RootExamples.EndAdd();
      testSample.EndAdd();
      //run the learning; the original author notes that TRUE forces the
      //"go back" behaviour - TODO confirm the exact meaning of the flag
      CompMetaSpv.Execute(TRUE);
      //fill the confusion matrix with the test sample
      tmpErr:= 0.0;
      for i:= 1 to testSample.Size do
       begin
        example:= testSample.Number[i];
        kClass:= ClassAtt.dValue[example];
        kPred:= CompMetaSpv.PredClass.dValue[example];
        //cell update
        tmpCM.CrossTab.Value[kClass,kPred]:= tmpCM.CrossTab.Value[kClass,kPred]+1;
        //margins: row 0 and column 0 hold the totals
        tmpCM.CrossTab.Value[kClass,0]:= tmpCM.CrossTab.Value[kClass,0]+1;
        tmpCM.CrossTab.Value[0,kPred]:= tmpCM.CrossTab.Value[0,kPred]+1;
        tmpCM.CrossTab.Value[0,0]:= tmpCM.CrossTab.Value[0,0]+1;
        //local error count, used for the log only
        if (kClass<>kPred)
         then tmpErr:= tmpErr+1.0;
       end;
      //an empty fold (fewer examples than folds) must not divide by zero;
      //its logged error rate stays 0
      if (testSample.Size>0)
       then tmpErr:= tmpErr/(1.0*testSample.Size);
      //log the sample sizes and the fold error rate
      TraceLog.WriteToLogFile(format('CV (rep:%d,fold:%d) -> %d train and %d test >>> %.4f',[r,f,self.RootExamples.Size,testSample.Size,tmpErr]));
     end;
    //error rate of this repetition
    FTabTestErr[r]:= tmpCM.getErrorRate();
    //track the extremes
    if (FTabTestErr[r]<FMin) then FMin:= FTabTestErr[r];
    if (FTabTestErr[r]>FMax) then FMax:= FTabTestErr[r];
    //accumulate into the global confusion matrix
    ConfMatrixAsses.addOtherConfMatrix(tmpCM);
   end;
 finally
  //Free is nil-safe, so this also works when a constructor failed
  tmpCM.Free;
  testSample.Free;
  sample.Free;
 end;
end;

{ Returns the parameter class used by this operator. }
function TOpAssesCV.getClassParameter: TClassOperatorParameter;
begin
  // The cross-validation operator is always configured by TOpPrmAssesCV.
  Result := TOpPrmAssesCV;
end;

{ Builds the HTML summary of the cross-validation results:
  a table with the MIN and MAX error rates followed by one row per trial,
  then the HTML output of the accumulated confusion matrix.

  The number of trial rows is limited to the entries actually present in
  FTabTestErr, so an empty or shorter table (for example after ReInitialize)
  is never read out of bounds. }
function TOpAssesCV.getHTMLResultsSummary: string;
var s: string;
    r,nbTrials: integer;
begin
 s:= '<P><B>CV error rate</B><BR>';
 s:= s+HTML_HEADER_TABLE_RESULT;
 s:= s+HTML_TABLE_COLOR_HEADER_GRAY+'<TH colspan="2">Range</TH></TR>';
 s:= s+HTML_TABLE_COLOR_DATA_BLUE+format('<TD>MIN</TD><TD align=right>%.4f</TD></TR>',[FMin]);
 s:= s+HTML_TABLE_COLOR_DATA_BLUE+format('<TD>MAX</TD><TD align=right>%.4f</TD></TR>',[FMax]);
 s:= s+HTML_TABLE_COLOR_HEADER_GRAY+'<TH>Trial</TH><TH>Err rate</TH></TR>';
 //the table is 1-based; High() is -1 for an empty dynamic array,
 //in which case no trial row is produced
 nbTrials:= (prmOP as TOpPrmAssesCV).NbRepetitions;
 if (nbTrials>High(FTabTestErr))
  then nbTrials:= High(FTabTestErr);
 for r:= 1 to nbTrials do
  s:= s+HTML_TABLE_COLOR_DATA_GRAY+format('<TD>%d</TD><TD align=right>%.4f</TD></TR>',[r,FTabTestErr[r]]);
 s:= s+'</table>';

 //overall confusion matrix, accumulated over all repetitions
 s:= s+'<P><B>Overall cross-validation error rate</B><BR>';
 s:= s+ConfMatrixAsses.getHTMLResults();

 result:= s;
end;

{ Runs the inherited preparation, then allocates the per-repetition error
  table. The table is indexed from 1, hence one slot more than the number
  of repetitions. }
procedure TOpAssesCV.PrepareConfMatrix;
var
  slotCount: integer;
begin
  inherited PrepareConfMatrix();
  slotCount := (PrmOp as TOpPrmAssesCV).NbRepetitions + 1;
  SetLength(FTabTestErr, slotCount);
end;

{ Runs the inherited re-initialization, then discards the per-repetition
  error table. }
procedure TOpAssesCV.ReInitialize;
begin
  inherited ReInitialize();
  // Assigning nil to a dynamic array releases it, leaving a length of 0.
  FTabTestErr := nil;
end;

initialization
 // Executed once when the unit is initialized: registers the generator class.
 // NOTE(review): presumably Classes.RegisterClass, which makes the class
 // retrievable by name - confirm no other used unit redefines RegisterClass.
 RegisterClass(TMLGenCompAssesCV);
end.
