You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
42 lines
1000 B
42 lines
1000 B
'''
|
|
Author: SJ2050
|
|
Date: 2022-01-21 12:01:47
|
|
LastEditTime: 2022-01-22 11:04:40
|
|
Version: v0.0.1
|
|
Description: Fully connected multi-layer perceptron (MLP) digit classifier using sklearn.
|
|
Copyright © 2022 SJ2050
|
|
'''
|
|
import numpy as np
from sklearn.neural_network import MLPClassifier
from sklearn.datasets import load_digits
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# Train an MLP classifier on the sklearn digits dataset, report its accuracy
# on a held-out tail of the data, and plot the resulting confusion matrix.

# Number of trailing samples held out as the test set.
N_TEST = 500
# Fixed seed for the classifier's internal randomness, so the reported
# accuracy is reproducible from run to run.
RANDOM_STATE = 0

# load data
digits = load_digits()
X = digits.data
Y = digits.target

# Min-max scale all features into [0, 1] using the global min/max.
# NOTE: this modifies digits.data in place, and the statistics are computed
# over the whole dataset (held-out test samples included).
X -= X.min()
X /= X.max()

# Every class label present in the data, in sorted order. Passing it to
# confusion_matrix guarantees one row/column per class in a fixed order,
# even if some class is absent from both y_test and the predictions.
labels = np.unique(Y)

x_train = X[:-N_TEST]
y_train = Y[:-N_TEST]

# `(100,)` is a one-element tuple: a single hidden layer of 100 units.
# Note that `(100)` without the comma would be a plain int, not a tuple.
mlp = MLPClassifier(hidden_layer_sizes=(100,), max_iter=10000,
                    random_state=RANDOM_STATE)
mlp.fit(x_train, y_train)

x_test = X[-N_TEST:]
y_test = Y[-N_TEST:]
predictions = mlp.predict(x_test)
acc = accuracy_score(y_test, predictions)
print('--------------------------------')
print(f'predict: acc = {acc}.')

# Rows are ground-truth classes, columns are predicted classes.
cm = confusion_matrix(y_test, predictions, labels=labels)
plt.matshow(cm)
plt.title('Confusion Matrix')
plt.colorbar()
# Label every row/column with its class instead of relying on automatic ticks.
plt.xticks(range(len(labels)), labels)
plt.yticks(range(len(labels)), labels)
plt.ylabel('Groundtruth')
plt.xlabel('Predict')
plt.show()