Vérification des versions¶

In [ ]:
#version de python
import sys
print(sys.version)
3.11.9 | packaged by Anaconda, Inc. | (main, Apr 19 2024, 16:40:41) [MSC v.1916 64 bit (AMD64)]
In [ ]:
#version de pytorch
import torch
print(torch.__version__)
2.2.0

Importation et préparation des données¶

Préparation usuelle des données

In [ ]:
#modifier le dossier de travail
import os
os.chdir("C:/Users/ricco/Desktop/demo")

#importer les données
import pandas
pdTrain = pandas.read_excel("breast_train.xlsx",header=0)
print(pdTrain.info())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 399 entries, 0 to 398
Data columns (total 7 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   ucellsize   399 non-null    int64 
 1   ucellshape  399 non-null    int64 
 2   mgadhesion  399 non-null    int64 
 3   sepics      399 non-null    int64 
 4   normnucl    399 non-null    int64 
 5   mitoses     399 non-null    int64 
 6   classe      399 non-null    object
dtypes: int64(6), object(1)
memory usage: 21.9+ KB
None
In [ ]:
#premières lignes
pdTrain.head()
Out[ ]:
ucellsize ucellshape mgadhesion sepics normnucl mitoses classe
0 1 1 1 2 1 1 begnin
1 3 2 1 3 6 1 begnin
2 9 7 3 4 7 1 malignant
3 10 10 7 10 2 1 malignant
4 8 8 4 10 1 1 malignant
In [ ]:
#structure avec les descripteurs
XTrain = pdTrain[pdTrain.columns[:-1]]
XTrain.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 399 entries, 0 to 398
Data columns (total 6 columns):
 #   Column      Non-Null Count  Dtype
---  ------      --------------  -----
 0   ucellsize   399 non-null    int64
 1   ucellshape  399 non-null    int64
 2   mgadhesion  399 non-null    int64
 3   sepics      399 non-null    int64
 4   normnucl    399 non-null    int64
 5   mitoses     399 non-null    int64
dtypes: int64(6)
memory usage: 18.8 KB
In [ ]:
#outil pour centrer et réduire les X
#avec StandardScaler
from sklearn.preprocessing import StandardScaler
sts = StandardScaler()

#données centrées et réduites
ZTrain = sts.fit_transform(XTrain)
print(ZTrain.shape)
(399, 6)
In [ ]:
#distribution des classes
pdTrain.classe.value_counts()
Out[ ]:
classe
begnin       267
malignant    132
Name: count, dtype: int64
In [ ]:
#recoder la cible en 0 : begnin ; 1 : malignant
yTrain = pdTrain.classe.map({'malignant':1,'begnin':0}).values

#comptage
import numpy
numpy.unique(yTrain,return_counts=True)
Out[ ]:
(array([0, 1], dtype=int64), array([267, 132], dtype=int64))

Rendre les données compatibles avec PyTorch

In [ ]:
#tranformer les variables Z en tensor
tensor_ZTrain = torch.FloatTensor(ZTrain)

#qui est d'un type particulier
print(type(tensor_ZTrain))
<class 'torch.Tensor'>
In [ ]:
#dimensions
print(tensor_ZTrain.shape)
torch.Size([399, 6])
In [ ]:
#valeurs
print(tensor_ZTrain)
tensor([[-0.6800, -0.7377, -0.5984, -0.5140, -0.5897, -0.3145],
        [-0.0134, -0.4001, -0.5984, -0.0531,  1.0487, -0.3145],
        [ 1.9865,  1.2875,  0.1164,  0.4077,  1.3764, -0.3145],
        ...,
        [-0.6800, -0.7377, -0.5984, -0.5140, -0.5897, -0.3145],
        [-0.6800, -0.7377, -0.5984, -0.5140, -0.2620, -0.3145],
        [ 0.3200,  0.9500, -0.5984, -0.5140,  0.0657, -0.3145]])
In [ ]:
#idem pour y
tensor_yTrain = torch.FloatTensor(yTrain)
print(tensor_yTrain)
tensor([0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 1., 0.,
        0., 0., 0., 1., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0.,
        1., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 1., 1., 1., 0.,
        1., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0.,
        0., 1., 1., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 1.,
        0., 0., 0., 1., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.,
        1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,
        1., 0., 1., 1., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0.,
        1., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 1., 1., 0., 1., 0., 0.,
        1., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1.,
        0., 1., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.,
        1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.,
        1., 1., 0., 1., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
        1., 1., 1., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        1., 1., 1., 1., 0., 0., 1., 0., 1., 1., 0., 1., 1., 0., 0., 0., 1., 1.,
        0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 1., 1., 1.,
        0., 0., 1.])
In [ ]:
#dimension
print(tensor_yTrain.shape)
torch.Size([399])

Perceptron simple avec PyTorch¶

Elaboration de la structure¶

Structure du réseau de neurones

In [ ]:
#classe de calcul - Perceptron simple
#héritier de torch.nn.Module
class MyPS(torch.nn.Module):
    #constructeur - liste des éléments
    #qui composent le réseau
    def __init__(self,p):
        #appel du constructeur de l'ancêtre
        super(MyPS,self).__init__()
        #couche d'entrée (p variables) vers sortie (1 neurone)
        self.layer1 = torch.nn.Linear(p,1)
        #fonction de transfert sigmoïde à la sortie
        self.ft1 = torch.nn.Sigmoid()
        
    #structuration des éléments
    #calcul de la sortie du réseau
    #à partir d'une matrice x en entrée
    def forward(self,x):
        #application de la combinaison linéaire
        comb_lin = self.layer1(x)
        #transformation avec la fonction de transfert
        proba = self.ft1(comb_lin)
        return proba

Choix de la fonction de perte à optimiser

In [ ]:
#fonction critère à optimiser
critere_ps = torch.nn.MSELoss()

Instanciation du modèle

In [ ]:
#instanciation du modele
#.shape[1] pour le nombre de descripteurs => 6
ps = MyPS(tensor_ZTrain.shape[1])

Choix de l'algorithme d'optimisation - Gradient stochastique

In [ ]:
#algorithme d'optimisation
#on lui passe les paramètres à manipuler
optimiseur_ps = torch.optim.Adam(ps.parameters())

Quelques vérifications - Coefficients et sortie du réseau

In [ ]:
#poids synaptiques
#initialisés aléatoirement
print(ps.layer1.weight)
Parameter containing:
tensor([[ 0.2754, -0.3938, -0.3197, -0.2911,  0.2502, -0.3007]],
       requires_grad=True)
In [ ]:
#et l'intercept
print(ps.layer1.bias)
Parameter containing:
tensor([-0.3651], requires_grad=True)
In [ ]:
#calculer des sorties du réseau avec ces poids initiaux (aléatoires)
#équivalent à yPred = ps(tensor_ZTrain)
#on a des probas d'appartenance ici
yPred = ps.forward(tensor_ZTrain)

#affichage des 10 premières valeurs
print(yPred[:10])
tensor([[0.5065],
        [0.5872],
        [0.4895],
        [0.1170],
        [0.1574],
        [0.5065],
        [0.5270],
        [0.5065],
        [0.5065],
        [0.5294]], grad_fn=<SliceBackward0>)
In [ ]:
#format de yPred - matrice en réalité
print(yPred.shape)
torch.Size([399, 1])
In [ ]:
#pour transformer en vecteur
print(yPred.squeeze()[:10])
tensor([0.5065, 0.5872, 0.4895, 0.1170, 0.1574, 0.5065, 0.5270, 0.5065, 0.5065,
        0.5294], grad_fn=<SliceBackward0>)
In [ ]:
#en effet
print(yPred.squeeze().shape)
torch.Size([399])

Valeur de départ de la perte

In [ ]:
#MSE au départ (avec les poids initiaux aléatoires)
MSE1st = critere_ps(yPred.squeeze(),tensor_yTrain)
print(MSE1st)
tensor(0.3336, grad_fn=<MseLossBackward0>)
In [ ]:
#vérification en passant par les vecteurs numpy
#on a bien une fonction de perte MSE
numpy.mean((yPred.squeeze().detach().numpy()-tensor_yTrain.numpy())**2)
Out[ ]:
0.33355263

Entraînement du modèle¶

Bien détailler la structure du code avec les différentes étapes.

In [ ]:
#fonction pour apprentissage avec les paramètres :
#X, y, instance de classe torch, critère à optimiser, algo d'optimisation...
#...et n_epochs nombre de passage sur la base
def train_session(X,y,classifier,criterion,optimizer,n_epochs=10000):
    #vecteur pour collecter le loss au fil des itérations
    losses = numpy.zeros(n_epochs)
    #itérer (boucle) pour optimiser - n_epochs fois sur la base
    for iter in range(n_epochs):
        #réinitialiser (ràz) le gradient
        #nécessaire à chaque passage sinon PyTorch accumule
        optimizer.zero_grad()
        #calculer la sortie du réseau
        yPred = classifier.forward(X) #ou simplement classifier(X)
        #calculer la perte
        perte = criterion(yPred.squeeze(),y)
        #collecter la perte calculée dans le vecteur losses
        losses[iter] = perte.item()
        #calcul du gradient et retropropagation
        perte.backward()
        #màj des poids synaptiques
        optimizer.step()
    #sortie de la boucle
    #renvoie le vecteur avec les valeurs de perte à chaque epoch
    return losses

On peut lancer l'entraînement du modèle maintenant.

In [ ]:
#lancer l'apprentissage
#revenir sur les paramètres passés à la fonction
pertes = train_session(tensor_ZTrain,tensor_yTrain,ps,critere_ps,optimiseur_ps)
In [ ]:
#valeur de la perte au final
print(pertes[-1])
0.03577861934900284
In [ ]:
#courbe de décroissance de la perte
import matplotlib.pyplot as plt
plt.plot(numpy.arange(0,pertes.shape[0]),pertes)
plt.title("Evolution fnct de perte")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.show()
No description has been provided for this image
In [ ]:
#inspection des coefficients après apprentissage
print(ps.layer1.weight)
Parameter containing:
tensor([[1.9119, 2.8555, 0.1655, 1.8033, 1.0651, 0.8123]], requires_grad=True)
In [ ]:
#et de l'intercept
print(ps.layer1.bias)
Parameter containing:
tensor([0.2614], requires_grad=True)

Evaluation sur l'échantillon test¶

Préparation des données

In [ ]:
#chargement de l'échantillon test
pdTest = pandas.read_excel("breast_test.xlsx",header=0)
print(pdTest.info())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 300 entries, 0 to 299
Data columns (total 7 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   ucellsize   300 non-null    int64 
 1   ucellshape  300 non-null    int64 
 2   mgadhesion  300 non-null    int64 
 3   sepics      300 non-null    int64 
 4   normnucl    300 non-null    int64 
 5   mitoses     300 non-null    int64 
 6   classe      300 non-null    object
dtypes: int64(6), object(1)
memory usage: 16.5+ KB
None
In [ ]:
#données centrées et réduites (avec les moyennes et écarts-type de train)
ZTest = sts.transform(pdTest[pdTest.columns[:-1]])
print(ZTest)
[[-0.67999948 -0.06259939 -0.59835544 -0.97482724 -0.58966111 -0.31445173]
 [ 2.31985082  0.27492976  1.54605012 -0.05313039  2.3594657  -0.31445173]
 [-0.67999948 -0.73765769 -0.59835544 -0.51397882 -0.58966111 -0.31445173]
 ...
 [ 0.98658402 -0.06259939  1.1886492   0.40771803  1.70410419  1.57699912]
 [-0.67999948 -0.06259939 -0.59835544 -0.51397882 -0.58966111 -0.31445173]
 [-0.67999948 -0.73765769 -0.59835544 -0.51397882 -0.58966111 -0.31445173]]
In [ ]:
#mettre au format tensor pour PyTorch
tensor_ZTest = torch.FloatTensor(ZTest)
print(tensor_ZTest)
tensor([[-0.6800, -0.0626, -0.5984, -0.9748, -0.5897, -0.3145],
        [ 2.3199,  0.2749,  1.5461, -0.0531,  2.3595, -0.3145],
        [-0.6800, -0.7377, -0.5984, -0.5140, -0.5897, -0.3145],
        ...,
        [ 0.9866, -0.0626,  1.1886,  0.4077,  1.7041,  1.5770],
        [-0.6800, -0.0626, -0.5984, -0.5140, -0.5897, -0.3145],
        [-0.6800, -0.7377, -0.5984, -0.5140, -0.5897, -0.3145]])
In [ ]:
#recoder en 0/1 la cible
yTest = pdTest.classe.map({'malignant':1,'begnin':0}).values
yTest
Out[ ]:
array([0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
       0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0,
       1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0,
       0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0,
       1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0,
       1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1,
       1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0,
       0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0,
       0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0,
       0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
       1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0,
       0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1,
       1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0], dtype=int64)
In [ ]:
#prédiction de la probabilité d'appartenance
#en appliquant le modèle sur les descripteurs standardisés
proba_predTest = ps.forward(tensor_ZTest)
print(proba_predTest)
tensor([[0.0187],
        [0.9996],
        [0.0063],
        [0.0307],
        [0.0063],
        [0.0063],
        [1.0000],
        [0.0165],
        [0.0063],
        [1.0000],
        [0.0063],
        [0.6162],
        [0.0067],
        [0.0090],
        [1.0000],
        [0.0028],
        [0.0063],
        [0.0063],
        [0.9998],
        [0.0028],
        [1.0000],
        [0.0063],
        [0.0071],
        [0.0165],
        [1.0000],
        [0.9869],
        [0.0751],
        [0.0067],
        [0.9999],
        [0.9370],
        [0.0063],
        [0.9385],
        [0.0063],
        [1.0000],
        [0.0052],
        [0.7079],
        [0.9076],
        [0.0063],
        [0.0067],
        [1.0000],
        [0.9237],
        [0.0063],
        [0.0119],
        [0.8312],
        [1.0000],
        [0.0029],
        [0.0028],
        [1.0000],
        [0.0063],
        [0.0168],
        [0.0063],
        [0.5992],
        [0.0063],
        [0.9999],
        [1.0000],
        [0.0153],
        [0.9999],
        [0.0071],
        [0.0028],
        [1.0000],
        [0.0028],
        [0.1589],
        [0.0063],
        [0.0090],
        [0.0063],
        [0.0090],
        [0.1429],
        [0.9999],
        [0.1870],
        [0.0063],
        [0.0071],
        [0.9988],
        [0.0063],
        [0.0063],
        [1.0000],
        [0.0063],
        [0.9901],
        [0.0063],
        [1.0000],
        [1.0000],
        [0.0598],
        [0.0153],
        [1.0000],
        [0.0063],
        [0.0063],
        [0.0063],
        [0.4521],
        [0.0067],
        [0.9991],
        [0.2164],
        [0.9923],
        [0.0063],
        [0.0063],
        [0.1274],
        [0.9925],
        [1.0000],
        [0.0063],
        [1.0000],
        [0.9988],
        [0.0063],
        [0.0063],
        [0.5801],
        [0.9997],
        [0.0063],
        [0.7242],
        [0.0063],
        [0.0063],
        [0.0063],
        [0.0063],
        [0.0067],
        [0.9933],
        [0.0063],
        [0.9994],
        [0.0063],
        [0.9927],
        [0.0063],
        [1.0000],
        [0.0063],
        [0.0144],
        [0.0063],
        [1.0000],
        [1.0000],
        [0.9872],
        [0.0144],
        [0.9985],
        [1.0000],
        [0.0144],
        [0.0028],
        [1.0000],
        [0.0865],
        [1.0000],
        [0.9975],
        [0.6464],
        [0.0063],
        [0.0063],
        [0.9999],
        [0.0063],
        [0.1031],
        [0.0063],
        [0.0242],
        [0.0067],
        [0.9183],
        [0.0063],
        [0.9999],
        [1.0000],
        [1.0000],
        [0.0063],
        [0.9813],
        [0.1402],
        [0.0530],
        [1.0000],
        [0.0063],
        [0.9699],
        [0.0236],
        [0.0028],
        [0.6594],
        [0.0071],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [0.0144],
        [0.0189],
        [0.0063],
        [0.0071],
        [0.0028],
        [0.3480],
        [0.0119],
        [0.7270],
        [0.0063],
        [0.0717],
        [0.0144],
        [1.0000],
        [0.0063],
        [0.0063],
        [0.0028],
        [0.0063],
        [0.0063],
        [0.0307],
        [0.9901],
        [1.0000],
        [0.9721],
        [1.0000],
        [0.0076],
        [0.0063],
        [0.1819],
        [0.0063],
        [0.0766],
        [0.9999],
        [1.0000],
        [0.0063],
        [0.0420],
        [0.0472],
        [0.1208],
        [0.9996],
        [0.0420],
        [0.0067],
        [0.0063],
        [0.0067],
        [0.0063],
        [0.0165],
        [1.0000],
        [1.0000],
        [1.0000],
        [0.0090],
        [1.0000],
        [0.0105],
        [1.0000],
        [0.0063],
        [0.0028],
        [0.0029],
        [0.0165],
        [0.0307],
        [0.0028],
        [0.2397],
        [0.0063],
        [0.0063],
        [0.9990],
        [1.0000],
        [0.0063],
        [0.9969],
        [0.0031],
        [0.0063],
        [0.9999],
        [0.0028],
        [1.0000],
        [1.0000],
        [0.6649],
        [0.0667],
        [0.0063],
        [0.0067],
        [0.0063],
        [0.0063],
        [0.0090],
        [0.9999],
        [1.0000],
        [0.0165],
        [0.0028],
        [1.0000],
        [0.0063],
        [0.0420],
        [0.0063],
        [0.0063],
        [0.9930],
        [1.0000],
        [0.0063],
        [0.9993],
        [0.0063],
        [1.0000],
        [0.0063],
        [0.0063],
        [0.0165],
        [1.0000],
        [1.0000],
        [0.0063],
        [0.0063],
        [0.0063],
        [0.0063],
        [0.9933],
        [0.0071],
        [0.0420],
        [0.0063],
        [0.0325],
        [0.0063],
        [0.0063],
        [1.0000],
        [0.1966],
        [0.0063],
        [0.0353],
        [0.0090],
        [0.0063],
        [0.0063],
        [0.0063],
        [0.9995],
        [0.0223],
        [1.0000],
        [0.9987],
        [0.3653],
        [0.0063],
        [0.0063],
        [1.0000],
        [0.3619],
        [0.0063],
        [0.9999],
        [0.9619],
        [0.9708],
        [0.9999],
        [1.0000],
        [0.9225],
        [0.6776],
        [0.0063],
        [0.9915],
        [0.1909],
        [0.0193],
        [1.0000],
        [0.0031],
        [0.0165],
        [0.9975],
        [0.0420],
        [0.0063]], grad_fn=<SigmoidBackward0>)
In [ ]:
#on a une matrice "torch"
proba_predTest.shape
Out[ ]:
torch.Size([300, 1])
In [ ]:
#que l'on transforme en vecteur numpy
vec_predTest = proba_predTest.detach().squeeze()
vec_predTest.shape
Out[ ]:
torch.Size([300])
In [ ]:
#que l'on peut convertir en classe d'appartenance
#en comparant à la valeur 0.5
classe_predTest = numpy.where(vec_predTest > 0.5,1.0,0.0)

#on a un vecteur numpy ici
print(classe_predTest)
[0. 1. 0. 0. 0. 0. 1. 0. 0. 1. 0. 1. 0. 0. 1. 0. 0. 0. 1. 0. 1. 0. 0. 0.
 1. 1. 0. 0. 1. 1. 0. 1. 0. 1. 0. 1. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 0. 1.
 0. 0. 0. 1. 0. 1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1.
 0. 0. 1. 0. 1. 0. 1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 1. 1.
 0. 1. 1. 0. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 1. 0. 1. 0. 1. 0. 1. 0. 0. 0.
 1. 1. 1. 0. 1. 1. 0. 0. 1. 0. 1. 1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 1.
 1. 1. 0. 1. 0. 0. 1. 0. 1. 0. 0. 1. 0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0.
 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0.
 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 1. 1. 0. 1. 0. 0. 1. 0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 1. 0.
 0. 0. 0. 1. 1. 0. 1. 0. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.
 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 1. 0. 0. 0. 1. 0. 0. 1. 1. 1. 1. 1.
 1. 1. 0. 1. 0. 0. 1. 0. 0. 1. 0. 0.]
In [ ]:
#calcul de l'accuracy
print(numpy.mean(classe_predTest == yTest))
0.9666666666666667

Perceptron multicouche¶

Définir la nouvelle structure du réseau.

In [ ]:
#perceptron multicouche
class MyPMC(torch.nn.Module):
    #constructeur
    def __init__(self,p):
        #appel du constructeur de l'ancêtre
        super(MyPMC,self).__init__()
        #couche d'entrée (p variables) vers intermédiaire (2 neurones)
        self.layer1 = torch.nn.Linear(p,2)
        #fonction de transfert
        self.ft1 = torch.nn.Sigmoid()
        #couche intermédiaire vers sortie (1 neurone)
        self.layer2 = torch.nn.Linear(2,1)
        #fonction de transfert sigmoïde
        self.ft2 = torch.nn.Sigmoid()
        
    #premier forward
    #couche entrée -> couche cachée
    def forward_1(self,x):
        #application de la combinaison linéaire
        comb_lin_1 = self.layer1(x)
        #appliquer la fonction sigmoïde
        return self.ft1(comb_lin_1)
    
    #second forward
    #couche cachée -> couche sortie
    def forward_2(self,x_prim):
        #puis seconde combinaison linéaire
        comb_lin_2 = self.layer2(x_prim)
        #appliquer la transformation sigmoïde
        return self.ft2(comb_lin_2)


    #calcul de la sortie du réseau
    #à partir d'une matrice x en entrée
    def forward(self,x):        
        #premier forward
        out_1 = self.forward_1(x)
        #second forward
        out_2 = self.forward_2(out_1)
        #return
        return out_2

Entraînement du modèle

In [ ]:
#instanciation
pmc = MyPMC(tensor_ZTrain.shape[1])
In [ ]:
#défintion des outils d'apprentissage
#critere - perte quadratique
critere_pmc = torch.nn.MSELoss()

#optimiseur
optimiseur_pmc = torch.optim.Adam(pmc.parameters())
In [ ]:
#lancement de l'entraînement
#en ré-explitant la fonction ci-dessus, seuls les paramètres changent
#dont l'instance de la structure du réseau "pmc"
pertes = train_session(tensor_ZTrain,tensor_yTrain,pmc,critere_pmc,optimiseur_pmc)
In [ ]:
#évolution de la perte vs. epoch
plt.plot(numpy.arange(0,pertes.shape[0]),pertes)
plt.title("Evolution fnct de perte")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.show()
No description has been provided for this image

Inspection des poids synaptiques (coefficients)

In [ ]:
#poids de entrée -> couche cachée
print(pmc.layer1.weight)
Parameter containing:
tensor([[-1.4811, -6.1619,  0.2026, -1.7708, -1.0070, -0.6883],
        [ 0.9091,  3.9800,  0.9036,  1.8425,  2.1994,  1.4802]],
       requires_grad=True)
In [ ]:
#intercept corresp.
print(pmc.layer1.bias)
Parameter containing:
tensor([-1.1908,  1.2490], requires_grad=True)
In [ ]:
#poids de couche cachée -> couche de sortie
print(pmc.layer2.weight)
Parameter containing:
tensor([[-3.6089,  3.8023]], requires_grad=True)
In [ ]:
#intercept corresp.
print(pmc.layer2.bias)
Parameter containing:
tensor([-1.0974], requires_grad=True)

Evaluation sur l'échantillon test

In [ ]:
#prédiction proba d'appartenance
proba_pmc = pmc.forward(tensor_ZTest)
print(proba_pmc)
tensor([[0.0125],
        [0.9372],
        [0.0092],
        [0.0108],
        [0.0092],
        [0.0092],
        [0.9373],
        [0.0102],
        [0.0092],
        [0.9373],
        [0.0092],
        [0.6626],
        [0.0092],
        [0.0093],
        [0.9373],
        [0.0091],
        [0.0092],
        [0.0092],
        [0.9373],
        [0.0091],
        [0.9373],
        [0.0092],
        [0.0093],
        [0.0102],
        [0.9373],
        [0.9366],
        [0.0137],
        [0.0092],
        [0.9373],
        [0.9318],
        [0.0092],
        [0.9146],
        [0.0092],
        [0.9373],
        [0.0091],
        [0.8038],
        [0.9300],
        [0.0092],
        [0.0092],
        [0.9373],
        [0.8562],
        [0.0092],
        [0.0093],
        [0.6329],
        [0.9373],
        [0.0091],
        [0.0091],
        [0.9373],
        [0.0092],
        [0.0095],
        [0.0092],
        [0.7078],
        [0.0092],
        [0.9373],
        [0.9373],
        [0.0096],
        [0.9373],
        [0.0093],
        [0.0091],
        [0.9373],
        [0.0091],
        [0.0191],
        [0.0092],
        [0.0093],
        [0.0092],
        [0.0093],
        [0.0219],
        [0.9373],
        [0.1469],
        [0.0092],
        [0.0093],
        [0.9373],
        [0.0092],
        [0.0092],
        [0.9373],
        [0.0092],
        [0.9366],
        [0.0092],
        [0.9373],
        [0.9373],
        [0.0150],
        [0.0096],
        [0.9373],
        [0.0092],
        [0.0092],
        [0.0092],
        [0.4909],
        [0.0092],
        [0.9373],
        [0.0159],
        [0.9368],
        [0.0092],
        [0.0092],
        [0.0134],
        [0.9347],
        [0.9373],
        [0.0092],
        [0.9373],
        [0.9373],
        [0.0092],
        [0.0092],
        [0.8155],
        [0.9370],
        [0.0092],
        [0.8813],
        [0.0092],
        [0.0092],
        [0.0092],
        [0.0092],
        [0.0092],
        [0.9372],
        [0.0092],
        [0.9373],
        [0.0092],
        [0.9370],
        [0.0092],
        [0.9373],
        [0.0092],
        [0.0095],
        [0.0092],
        [0.9373],
        [0.9373],
        [0.9364],
        [0.0095],
        [0.9373],
        [0.9373],
        [0.0095],
        [0.0091],
        [0.9373],
        [0.0748],
        [0.9373],
        [0.9365],
        [0.5102],
        [0.0092],
        [0.0092],
        [0.9372],
        [0.0092],
        [0.1221],
        [0.0092],
        [0.0109],
        [0.0092],
        [0.9317],
        [0.0092],
        [0.9373],
        [0.9373],
        [0.9373],
        [0.0092],
        [0.9261],
        [0.2326],
        [0.0135],
        [0.9373],
        [0.0092],
        [0.9337],
        [0.0095],
        [0.0091],
        [0.7144],
        [0.0093],
        [0.9373],
        [0.9373],
        [0.9373],
        [0.9373],
        [0.0095],
        [0.0098],
        [0.0092],
        [0.0093],
        [0.0091],
        [0.4352],
        [0.0093],
        [0.7726],
        [0.0092],
        [0.0119],
        [0.0095],
        [0.9373],
        [0.0092],
        [0.0092],
        [0.0091],
        [0.0092],
        [0.0092],
        [0.0108],
        [0.9361],
        [0.9373],
        [0.9360],
        [0.9373],
        [0.0094],
        [0.0092],
        [0.0654],
        [0.0092],
        [0.0249],
        [0.9373],
        [0.9373],
        [0.0092],
        [0.0182],
        [0.0159],
        [0.0167],
        [0.9373],
        [0.0182],
        [0.0092],
        [0.0092],
        [0.0092],
        [0.0092],
        [0.0102],
        [0.9373],
        [0.9373],
        [0.9373],
        [0.0093],
        [0.9373],
        [0.0094],
        [0.9373],
        [0.0092],
        [0.0091],
        [0.0091],
        [0.0102],
        [0.0108],
        [0.0091],
        [0.1325],
        [0.0092],
        [0.0092],
        [0.9373],
        [0.9373],
        [0.0092],
        [0.9371],
        [0.0091],
        [0.0092],
        [0.9373],
        [0.0091],
        [0.9373],
        [0.9373],
        [0.9006],
        [0.0134],
        [0.0092],
        [0.0092],
        [0.0092],
        [0.0092],
        [0.0093],
        [0.9373],
        [0.9373],
        [0.0102],
        [0.0091],
        [0.9373],
        [0.0092],
        [0.0182],
        [0.0092],
        [0.0092],
        [0.9371],
        [0.9373],
        [0.0092],
        [0.9372],
        [0.0092],
        [0.9373],
        [0.0092],
        [0.0092],
        [0.0102],
        [0.9373],
        [0.9373],
        [0.0092],
        [0.0092],
        [0.0092],
        [0.0092],
        [0.9366],
        [0.0093],
        [0.0182],
        [0.0092],
        [0.0101],
        [0.0092],
        [0.0092],
        [0.9373],
        [0.2718],
        [0.0092],
        [0.0150],
        [0.0093],
        [0.0092],
        [0.0092],
        [0.0092],
        [0.9372],
        [0.0094],
        [0.9373],
        [0.9373],
        [0.1605],
        [0.0092],
        [0.0092],
        [0.9373],
        [0.6371],
        [0.0092],
        [0.9373],
        [0.9279],
        [0.9366],
        [0.9373],
        [0.9373],
        [0.9284],
        [0.8799],
        [0.0092],
        [0.9371],
        [0.0742],
        [0.0107],
        [0.9373],
        [0.0091],
        [0.0102],
        [0.9365],
        [0.0182],
        [0.0092]], grad_fn=<SigmoidBackward0>)
In [ ]:
#conversion en classe d'appartenance
classe_pmc = numpy.where(proba_pmc.detach().squeeze() > 0.5,1.0,0.0)
print(classe_pmc)
[0. 1. 0. 0. 0. 0. 1. 0. 0. 1. 0. 1. 0. 0. 1. 0. 0. 0. 1. 0. 1. 0. 0. 0.
 1. 1. 0. 0. 1. 1. 0. 1. 0. 1. 0. 1. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 0. 1.
 0. 0. 0. 1. 0. 1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1.
 0. 0. 1. 0. 1. 0. 1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 1. 1.
 0. 1. 1. 0. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 1. 0. 1. 0. 1. 0. 1. 0. 0. 0.
 1. 1. 1. 0. 1. 1. 0. 0. 1. 0. 1. 1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 1.
 1. 1. 0. 1. 0. 0. 1. 0. 1. 0. 0. 1. 0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0.
 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0.
 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 1. 1. 0. 1. 0. 0. 1. 0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 1. 0.
 0. 0. 0. 1. 1. 0. 1. 0. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.
 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 1. 0. 0. 0. 1. 1. 0. 1. 1. 1. 1. 1.
 1. 1. 0. 1. 0. 0. 1. 0. 0. 1. 0. 0.]
In [ ]:
#accuracy
print(numpy.mean(classe_predTest == yTest))
0.9666666666666667

Représentation intermédiaire de la couche cachée

In [ ]:
#obtenir la sortie de la couche cachée
#renvoyée par forward_1
hidden = pmc.forward_1(tensor_ZTrain)
print(hidden)
tensor([[9.9743e-01, 3.8527e-03],
        [6.0534e-01, 6.9992e-01],
        [8.8778e-07, 9.9999e-01],
        [5.4958e-11, 1.0000e+00],
        [1.0575e-08, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [9.9643e-01, 7.8887e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9579e-01, 5.2092e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [1.1559e-03, 9.6520e-01],
        [5.9068e-11, 1.0000e+00],
        [2.1572e-04, 9.9997e-01],
        [9.9743e-01, 3.8527e-03],
        [1.7370e-03, 9.9939e-01],
        [8.8871e-01, 5.9714e-02],
        [1.3007e-01, 9.9032e-01],
        [1.3884e-05, 1.0000e+00],
        [8.5830e-01, 5.3738e-02],
        [9.4992e-11, 1.0000e+00],
        [1.4907e-05, 9.9999e-01],
        [9.9743e-01, 3.8527e-03],
        [9.8117e-01, 2.0059e-02],
        [2.5587e-11, 1.0000e+00],
        [2.5667e-12, 1.0000e+00],
        [1.0692e-09, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [9.9886e-01, 1.6518e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9886e-01, 1.6518e-03],
        [4.2081e-03, 9.9446e-01],
        [9.9743e-01, 3.8527e-03],
        [8.5830e-01, 5.3738e-02],
        [9.7352e-01, 5.0215e-02],
        [4.3081e-01, 1.7872e-01],
        [9.9643e-01, 7.8887e-03],
        [9.9886e-01, 1.6518e-03],
        [3.1143e-04, 9.9760e-01],
        [4.0208e-08, 9.9993e-01],
        [9.9313e-01, 7.0398e-03],
        [9.9743e-01, 3.8527e-03],
        [7.8711e-01, 7.1399e-02],
        [9.9743e-01, 3.8527e-03],
        [1.8419e-06, 9.9999e-01],
        [1.4090e-04, 9.9957e-01],
        [5.2251e-02, 6.8551e-01],
        [4.9220e-01, 6.3312e-01],
        [9.9777e-01, 7.3242e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9886e-01, 1.6518e-03],
        [9.9743e-01, 3.8527e-03],
        [9.7979e-01, 1.4604e-02],
        [9.7158e-11, 1.0000e+00],
        [2.4531e-06, 9.9988e-01],
        [3.8992e-04, 9.9891e-01],
        [8.5830e-01, 5.3738e-02],
        [9.9886e-01, 1.6518e-03],
        [1.5395e-05, 9.9982e-01],
        [9.8117e-01, 2.0059e-02],
        [9.9886e-01, 1.6518e-03],
        [5.1304e-03, 9.9315e-01],
        [9.5543e-01, 3.3484e-02],
        [9.5811e-01, 5.3906e-02],
        [9.7979e-01, 1.4604e-02],
        [9.9777e-01, 7.3242e-03],
        [9.9311e-01, 3.2516e-02],
        [1.0316e-01, 9.8120e-01],
        [1.6145e-11, 1.0000e+00],
        [2.6477e-03, 9.9993e-01],
        [9.9777e-01, 7.3242e-03],
        [2.5092e-09, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [2.4155e-11, 1.0000e+00],
        [3.5294e-06, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [9.9643e-01, 7.8887e-03],
        [9.9743e-01, 3.8527e-03],
        [1.5660e-04, 9.9999e-01],
        [5.6243e-08, 9.9999e-01],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [1.0670e-03, 9.9988e-01],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [1.8448e-12, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [2.7865e-03, 9.9647e-01],
        [2.9886e-07, 1.0000e+00],
        [1.1033e-10, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [1.4845e-04, 9.9648e-01],
        [9.9743e-01, 3.8527e-03],
        [9.9579e-01, 5.2092e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9901e-01, 3.1464e-03],
        [2.4951e-10, 1.0000e+00],
        [8.5830e-01, 5.3738e-02],
        [9.9743e-01, 3.8527e-03],
        [9.7979e-01, 1.4604e-02],
        [9.9886e-01, 1.6518e-03],
        [6.3078e-05, 1.0000e+00],
        [9.4788e-01, 5.9701e-01],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [8.5830e-01, 5.3738e-02],
        [2.0504e-09, 1.0000e+00],
        [1.8329e-08, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [2.3245e-01, 9.9890e-01],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9814e-01, 2.2351e-03],
        [9.9743e-01, 3.8527e-03],
        [3.8827e-09, 1.0000e+00],
        [1.3917e-02, 7.9466e-01],
        [1.8448e-12, 1.0000e+00],
        [9.3678e-01, 1.2654e-01],
        [8.6688e-01, 7.2733e-02],
        [9.9743e-01, 3.8527e-03],
        [1.2383e-01, 6.9659e-01],
        [9.9743e-01, 3.8527e-03],
        [7.5396e-01, 8.5081e-01],
        [9.9743e-01, 3.8527e-03],
        [1.3128e-10, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [9.5502e-02, 7.4431e-01],
        [9.9743e-01, 3.8527e-03],
        [6.9585e-05, 9.9997e-01],
        [9.9743e-01, 3.8527e-03],
        [1.1055e-05, 9.9998e-01],
        [9.9886e-01, 1.6518e-03],
        [9.9743e-01, 3.8527e-03],
        [7.8349e-01, 3.1409e-01],
        [9.9743e-01, 3.8527e-03],
        [9.9780e-01, 6.9446e-03],
        [9.3485e-01, 5.2395e-01],
        [8.8307e-04, 9.8842e-01],
        [9.9743e-01, 3.8527e-03],
        [1.0994e-02, 8.9741e-01],
        [9.9743e-01, 3.8527e-03],
        [9.7979e-01, 1.4604e-02],
        [2.9762e-05, 9.9999e-01],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9886e-01, 1.6518e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [4.3081e-01, 1.7872e-01],
        [7.2814e-01, 1.1719e-01],
        [9.9743e-01, 3.8527e-03],
        [3.6993e-03, 9.9986e-01],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [3.8542e-04, 9.9996e-01],
        [9.9743e-01, 3.8527e-03],
        [1.4371e-11, 1.0000e+00],
        [4.3835e-03, 8.5417e-01],
        [2.0905e-11, 1.0000e+00],
        [2.0697e-07, 1.0000e+00],
        [9.9596e-01, 4.3470e-02],
        [1.5809e-03, 9.9922e-01],
        [2.1112e-07, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.5102e-01, 3.6166e-02],
        [4.0420e-05, 9.9975e-01],
        [9.9743e-01, 3.8527e-03],
        [9.9901e-01, 3.1464e-03],
        [9.9761e-01, 5.3136e-03],
        [9.8629e-01, 1.1717e-02],
        [3.9986e-07, 9.9989e-01],
        [9.9420e-01, 8.9601e-03],
        [9.9743e-01, 3.8527e-03],
        [2.7632e-06, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [9.5042e-01, 4.8544e-02],
        [8.7502e-01, 9.7747e-02],
        [9.9743e-01, 3.8527e-03],
        [1.1376e-03, 9.9864e-01],
        [9.6731e-01, 1.9671e-02],
        [1.1008e-07, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [1.2752e-10, 1.0000e+00],
        [8.7909e-01, 2.1496e-01],
        [9.9743e-01, 3.8527e-03],
        [4.6717e-06, 9.9999e-01],
        [9.5375e-02, 9.3737e-01],
        [9.4753e-01, 2.6449e-02],
        [3.0238e-01, 7.9268e-01],
        [9.9313e-01, 7.0398e-03],
        [8.5830e-01, 5.3738e-02],
        [3.2934e-11, 1.0000e+00],
        [4.8336e-06, 9.9999e-01],
        [9.9743e-01, 3.8527e-03],
        [9.1229e-06, 9.9998e-01],
        [9.9886e-01, 1.6518e-03],
        [8.8721e-06, 1.0000e+00],
        [8.8871e-01, 5.9714e-02],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.8438e-01, 7.6545e-02],
        [6.9294e-01, 9.4286e-02],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [4.1587e-01, 3.7333e-01],
        [6.3047e-12, 1.0000e+00],
        [4.9422e-09, 1.0000e+00],
        [8.8690e-07, 1.0000e+00],
        [9.9777e-01, 7.3242e-03],
        [3.9736e-07, 9.9999e-01],
        [1.6319e-07, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [8.2265e-08, 1.0000e+00],
        [2.4070e-12, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [9.9579e-01, 5.2092e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9604e-01, 9.7384e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.7979e-01, 1.4604e-02],
        [9.9636e-01, 9.8907e-03],
        [9.9886e-01, 1.6518e-03],
        [5.7290e-02, 9.9511e-01],
        [1.4279e-01, 7.7239e-01],
        [9.7159e-01, 3.6868e-02],
        [9.0668e-02, 8.5570e-01],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.7979e-01, 1.4604e-02],
        [9.9743e-01, 3.8527e-03],
        [2.1339e-11, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [8.5830e-01, 5.3738e-02],
        [9.9504e-01, 1.6084e-02],
        [1.8279e-06, 9.9999e-01],
        [2.2206e-12, 1.0000e+00],
        [9.9196e-01, 1.8248e-02],
        [8.5830e-01, 5.3738e-02],
        [3.9365e-05, 9.9983e-01],
        [5.5719e-02, 9.9621e-01],
        [9.9743e-01, 3.8527e-03],
        [9.9196e-01, 1.8248e-02],
        [9.9743e-01, 3.8527e-03],
        [3.1268e-08, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [9.9886e-01, 1.6518e-03],
        [9.8528e-01, 8.5107e-03],
        [9.9777e-01, 7.3242e-03],
        [9.9743e-01, 3.8527e-03],
        [8.6688e-01, 7.2733e-02],
        [9.9504e-01, 1.6084e-02],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [7.8711e-01, 7.1399e-02],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9886e-01, 1.6518e-03],
        [1.2383e-01, 6.9659e-01],
        [9.7979e-01, 1.4604e-02],
        [9.9886e-01, 1.6518e-03],
        [9.9743e-01, 3.8527e-03],
        [9.7159e-01, 3.6868e-02],
        [9.9743e-01, 3.8527e-03],
        [9.7979e-01, 1.4604e-02],
        [9.9743e-01, 3.8527e-03],
        [9.9886e-01, 1.6518e-03],
        [9.9886e-01, 1.6518e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.7979e-01, 1.4604e-02],
        [9.9761e-01, 5.3136e-03],
        [8.1750e-03, 9.9927e-01],
        [8.1695e-02, 6.9538e-01],
        [6.3666e-03, 9.6020e-01],
        [1.6474e-02, 8.3151e-01],
        [9.8112e-14, 1.0000e+00],
        [2.0200e-03, 9.9976e-01],
        [9.9743e-01, 3.8527e-03],
        [9.8112e-14, 1.0000e+00],
        [9.9420e-01, 8.9601e-03],
        [9.9886e-01, 1.6518e-03],
        [3.2884e-09, 1.0000e+00],
        [8.1012e-05, 1.0000e+00],
        [4.3833e-03, 9.5896e-01],
        [6.6876e-02, 9.9631e-01],
        [4.3081e-01, 1.7872e-01],
        [9.9777e-01, 7.3242e-03],
        [9.9743e-01, 3.8527e-03],
        [9.7979e-01, 1.4604e-02],
        [9.9777e-01, 7.3242e-03],
        [4.8033e-05, 1.0000e+00],
        [9.9609e-01, 7.1806e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.7979e-01, 1.4604e-02],
        [8.5830e-01, 5.3738e-02],
        [9.9743e-01, 3.8527e-03],
        [9.9579e-01, 5.2092e-03],
        [9.9743e-01, 3.8527e-03],
        [9.5811e-01, 5.3906e-02],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [5.1756e-01, 2.5156e-01],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [3.1799e-04, 9.9978e-01],
        [9.9743e-01, 3.8527e-03],
        [6.8667e-06, 9.9998e-01],
        [4.5139e-06, 9.9985e-01],
        [7.0941e-04, 9.9335e-01],
        [9.9743e-01, 3.8527e-03],
        [9.9643e-01, 7.8887e-03],
        [9.9743e-01, 3.8527e-03],
        [9.6731e-01, 1.9671e-02],
        [5.9068e-11, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [4.4084e-12, 1.0000e+00],
        [9.9886e-01, 1.6518e-03],
        [9.8247e-01, 2.7495e-02],
        [9.9777e-01, 7.3242e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [9.7979e-01, 1.4604e-02],
        [2.5958e-06, 1.0000e+00],
        [5.2832e-06, 1.0000e+00],
        [9.8268e-03, 9.9488e-01],
        [9.9743e-01, 3.8527e-03],
        [3.2545e-05, 9.9996e-01],
        [5.5575e-06, 1.0000e+00],
        [8.6688e-01, 7.2733e-02],
        [9.5811e-01, 5.3906e-02],
        [9.9886e-01, 1.6518e-03],
        [9.8459e-01, 1.6303e-02],
        [9.8587e-01, 3.8178e-01],
        [2.7475e-03, 9.9798e-01],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [5.3561e-01, 3.1705e-01],
        [9.9761e-01, 5.3136e-03],
        [8.7230e-03, 9.9683e-01],
        [9.9604e-01, 9.7384e-03],
        [2.5587e-11, 1.0000e+00],
        [1.2002e-06, 9.9999e-01],
        [6.5977e-05, 9.9962e-01],
        [2.9361e-12, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [1.5040e-04, 9.9995e-01],
        [9.9743e-01, 3.8527e-03],
        [1.9818e-02, 9.9729e-01],
        [2.2006e-11, 1.0000e+00],
        [9.7211e-01, 2.9568e-02],
        [1.7799e-01, 8.6928e-01],
        [4.5214e-13, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [9.9743e-01, 3.8527e-03],
        [7.8711e-01, 7.1399e-02],
        [2.4778e-03, 9.8971e-01],
        [2.0814e-08, 1.0000e+00],
        [9.9743e-01, 3.8527e-03],
        [9.9047e-01, 1.4366e-02],
        [9.9643e-01, 7.8887e-03],
        [9.9743e-01, 3.8527e-03],
        [1.7119e-06, 1.0000e+00],
        [9.9420e-01, 8.9601e-03],
        [9.9743e-01, 3.8527e-03],
        [2.3623e-06, 9.9992e-01],
        [9.9743e-01, 3.8527e-03],
        [9.9886e-01, 1.6518e-03],
        [1.4523e-08, 1.0000e+00],
        [3.5438e-05, 9.9985e-01],
        [5.5212e-09, 1.0000e+00],
        [9.9504e-01, 1.6084e-02],
        [9.9743e-01, 3.8527e-03],
        [8.7814e-09, 1.0000e+00],
        [4.1517e-09, 1.0000e+00],
        [1.1503e-04, 9.9998e-01],
        [9.9743e-01, 3.8527e-03],
        [9.9643e-01, 7.8887e-03],
        [1.3871e-03, 9.7103e-01]], grad_fn=<SigmoidBackward0>)
In [ ]:
#sous la forme d'un data frame pandas
p_hidden = pandas.DataFrame(hidden.detach().numpy(),columns=['F1','F2'])
p_hidden.head()
Out[ ]:
F1 F2
0 9.974290e-01 0.003853
1 6.053422e-01 0.699923
2 8.877845e-07 0.999991
3 5.495754e-11 1.000000
4 1.057500e-08 0.999999
In [ ]:
#associer la classe d'apparenance au data frame
p_hidden['classe'] = pdTrain.classe
p_hidden.head()
Out[ ]:
F1 F2 classe
0 9.974290e-01 0.003853 begnin
1 6.053422e-01 0.699923 begnin
2 8.877845e-07 0.999991 malignant
3 5.495754e-11 1.000000 malignant
4 1.057500e-08 0.999999 malignant
In [ ]:
#affichage des points dans le plan "factoriel"
import seaborn as sns
sns.scatterplot(data=p_hidden,x='F1',y='F2',hue='classe')
Out[ ]:
<Axes: xlabel='F1', ylabel='F2'>
No description has been provided for this image
In [ ]:
#rappel de la droite de séparation : couche cachée -> sortie
print(f"Coefficients : {pmc.layer2.weight}")
print(f"Intercept : {pmc.layer2.bias}")
Coefficients : Parameter containing:
tensor([[-3.6089,  3.8023]], requires_grad=True)
Intercept : Parameter containing:
tensor([-1.0974], requires_grad=True)
In [ ]:
#récupération des coefficients et de l'intercept
coef = pmc.layer2.weight.detach().numpy()[0]
intercept = pmc.layer2.bias.detach().numpy()
print(coef)
print(intercept)
[-3.60885    3.8022952]
[-1.0974166]
In [ ]:
#coordoonées de F2 quand F1 = 0
f2_0 = -intercept/coef[1]
print(f2_0)
[0.28861952]
In [ ]:
#coordoonées de F2 quand F1 = 1
f2_1 = (-intercept - coef[0] * 1)/coef[1]
print(f2_1)
[1.2377436]
In [ ]:
#droite de séparation dans l'espace intermédiaire
plt.plot(numpy.array([0,1]),numpy.array([f2_0,f2_1]),'k-')
sns.scatterplot(data=p_hidden,x='F1',y='F2',hue='classe')
plt.show()
No description has been provided for this image