Version¶
In [1]:
import sentence_transformers
sentence_transformers.__version__
Out[1]:
'5.1.2'
Chargement du modèle pré-entraîné - Vérification¶
In [2]:
# importation de la fonction
from sentence_transformers import SentenceTransformer
# chargement du modèle pré-entraîné
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
In [3]:
# phrase
phrase = ['hello test embedding']
# essai encodage
embedding = model.encode(phrase)
# type structure
print(type(embedding))
#taille
print(embedding.shape)
<class 'numpy.ndarray'> (1, 384)
In [4]:
# voir un peu
embedding[:,:10]
Out[4]:
array([[-0.008262 , -0.02400577, 0.00214335, 0.00408584, 0.00906912,
0.01457942, -0.00903908, -0.00253245, -0.06765494, -0.03728982]],
dtype=float32)
Importation du corpus - Inpection, préparation¶
In [5]:
# changement de dossier
import os
os.chdir('C:/Users/ricco/Desktop/demo')
In [6]:
# chargement
import pandas
df = pandas.read_excel("reuters_r8.xlsx")
df.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 7674 entries, 0 to 7673 Data columns (total 2 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 classe 7674 non-null object 1 texte 7674 non-null object dtypes: object(2) memory usage: 120.0+ KB
In [7]:
# premières lignes
df.head()
Out[7]:
| classe | texte | |
|---|---|---|
| 0 | trade | asian exporters fear damage from u s japan rif... |
| 1 | grain | china daily says vermin eat pct grain stocks a... |
| 2 | ship | australian foreign ship ban ends but nsw ports... |
| 3 | acq | sumitomo bank aims at quick recovery from merg... |
| 4 | earn | amatil proposes two for five bonus share issue... |
In [8]:
# distribution des classes
df.classe.value_counts()
Out[8]:
classe earn 3923 acq 2292 crude 374 trade 326 money-fx 293 interest 271 ship 144 grain 51 Name: count, dtype: int64
In [9]:
# partition train-test
from sklearn.model_selection import train_test_split
dfTrain, dfTest = train_test_split(df,train_size=0.7,stratify=df.classe,random_state=0)
print(dfTrain.shape)
print(dfTest.shape)
(5371, 2) (2303, 2)
Apprentissage¶
In [10]:
# matrice pour train
XTrain = model.encode(dfTrain.texte.to_list())
print(XTrain.shape)
(5371, 384)
In [11]:
# modélisation avec la régression logistique
from sklearn.linear_model import LogisticRegression
lr = LogisticRegression()
lr.fit(XTrain,dfTrain.classe)
Out[11]:
LogisticRegression()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
Parameters
| penalty | 'l2' | |
| dual | False | |
| tol | 0.0001 | |
| C | 1.0 | |
| fit_intercept | True | |
| intercept_scaling | 1 | |
| class_weight | None | |
| random_state | None | |
| solver | 'lbfgs' | |
| max_iter | 100 | |
| multi_class | 'deprecated' | |
| verbose | 0 | |
| warm_start | False | |
| n_jobs | None | |
| l1_ratio | None |
Evaluation¶
In [12]:
# encodage - matrice pour test
XTest = model.encode(dfTest.texte.to_list())
print(XTest.shape)
(2303, 384)
In [13]:
# prediction
pred = lr.predict(XTest)
# distribution des classes prédites
import numpy
numpy.unique(pred,return_counts=True)
Out[13]:
(array(['acq', 'crude', 'earn', 'grain', 'interest', 'money-fx', 'ship',
'trade'], dtype=object),
array([ 714, 103, 1172, 11, 76, 90, 35, 102]))
In [14]:
# matrice de confusion
from sklearn import metrics
metrics.confusion_matrix(dfTest.classe,pred)
Out[14]:
array([[ 677, 0, 10, 0, 0, 0, 0, 1],
[ 4, 101, 5, 0, 0, 0, 1, 1],
[ 22, 0, 1154, 0, 0, 1, 0, 0],
[ 0, 0, 0, 11, 0, 0, 0, 4],
[ 1, 0, 2, 0, 72, 7, 0, 0],
[ 2, 0, 0, 0, 4, 81, 0, 1],
[ 4, 2, 1, 0, 0, 0, 34, 2],
[ 4, 0, 0, 0, 0, 1, 0, 93]])
In [15]:
# accuracy
print(metrics.accuracy_score(dfTest.classe,pred))
0.9652627008250109