MLP网络可视化

目标:

  1. 学习神经网络在训练过程中,隐含层的特征变化机制

代码中的学习器为多层感知机(multi-layer perceptron, MLP)

参考:

葛家驿,杨乃森,唐宏,徐朋磊,纪超.端到端的梯度提升网络分类过程可视化[J].信号处理,2022,38(02):355-366.DOI:10.16798/j.issn.1003-0530.2022.02.015.

研究准备

环境配置

[ ]:
'''
全连接网络分类过程可视化
'''
import numpy as np
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from tensorflow.keras.callbacks import CSVLogger
from sklearn.preprocessing import MinMaxScaler
from matplotlib.colors import ListedColormap
from matplotlib import cm
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_circles #同心圆数据
from sklearn.model_selection import train_test_split
import sys
sys.setrecursionlimit(500000)
import imageio
%pylab inline

Populating the interactive namespace from numpy and matplotlib

设置模拟数据

[ ]:
# Generate and preprocess the simulated two-class dataset.

n_samples = 1000  # number of sample points
test_size = 0.5   # fraction of samples held out for testing

# Concentric-circle data; use the n_samples variable instead of repeating the
# literal 1000 so the two stay consistent. factor is the inner/outer radius
# ratio, noise the Gaussian jitter.
X, y = make_circles(n_samples=n_samples, factor=.4, noise=.06, random_state=0)

# Split into training and test sets
X_train, X_test, Y_train, Y_test = train_test_split(X, y, test_size=test_size, random_state=2)

# Regular 50x50 grid covering the data range (with a .2 margin) used later to
# paint the decision surface
c, r = np.mgrid[[slice(X.min() - .2, X.max() + .2, 50j)] * 2]
p = np.c_[c.flat, r.flat]

# Standardize everything with statistics fitted on the training set only
ss = StandardScaler().fit(X_train)
X = ss.transform(X)
p = ss.transform(p)
X_train = ss.transform(X_train)
X_test = ss.transform(X_test)

配置绘图环境

[ ]:
#设置画布大小和颜色
# Build an orange-to-blue diverging colormap for the two classes
fig = plt.figure(figsize=(9, 3))
upper = cm.get_cmap('Oranges_r', 512)(np.linspace(0.55, 1, 512))
lower = cm.get_cmap('Blues', 512)(np.linspace(0, 0.75, 512))
cm_bright = ListedColormap(np.vstack((upper, lower)), name='OrangeBlue')

n_train = int(n_samples * (1 - test_size))
n_test = int(n_samples * test_size)

# Left panel: training split
plt.subplot(121)
m1 = plt.scatter(*X_train.T, c=Y_train, cmap=cm_bright, edgecolors='white', s=20, linewidths=0.5)
plt.title(f'train data ({n_train} points)')
plt.axis('equal')

# Right panel: test split
plt.subplot(122)
m2 = plt.scatter(*X_test.T, c=Y_test, cmap=cm_bright, edgecolors='white', s=20, linewidths=0.5)
plt.title(f'test data ({n_test} points)')
plt.axis('equal')

# One shared colorbar for both panels
ax = fig.get_axes()
plt.colorbar(ax=ax)
plt.show()

# Full dataset in a single panel
fig = plt.figure(figsize=(7, 6))
plt.scatter(*X.T, c=y, cmap=cm_bright, edgecolors='white', s=20, linewidths=0.5)
plt.title(f'Raw data ({n_samples} points)')
plt.axis('equal')
plt.show()
../../_images/1stPart_Homework.1_MLP_7_0.png
../../_images/1stPart_Homework.1_MLP_7_1.png

配置数据与损失函数

[ ]:
num_classes=2 #number of target classes (inner vs. outer circle)
y_train=keras.utils.to_categorical(Y_train,num_classes) #convert integer labels to one-hot vectors
y_test=keras.utils.to_categorical(Y_test,num_classes)
#define the loss/accuracy curve plotting function

def plot_loss_accuracy(history, title_text, file_name):
    """Plot training/test accuracy and loss on twin y-axes and save the figure.

    Parameters
    ----------
    history : Keras ``History`` returned by ``Model.fit``; ``history.history``
        must contain 'accuracy', 'val_accuracy', 'loss' and 'val_loss'.
    title_text : str, figure suptitle.
    file_name : str, path the figure is written to via ``plt.savefig``.
    """
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()  # loss shares the epoch axis with accuracy
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Accuracy")
    ax2.set_ylabel("Loss")

    ax1.plot(history.epoch,
             history.history['accuracy'],
             label="Training Accuracy")
    ax1.plot(history.epoch,
             history.history['val_accuracy'],
             linestyle='--',
             label="Test Accuracy")

    # Continue the default color cycle on the loss curves explicitly ('C2',
    # 'C3' follow the two colors used above) instead of assigning the private
    # ax._get_lines.prop_cycler attribute, which was removed in matplotlib 3.8
    # and raised AttributeError there.
    ax2.plot(history.epoch,
             history.history['loss'],
             color='C2',
             label="Training Loss")
    ax2.plot(history.epoch,
             history.history['val_loss'],
             color='C3',
             linestyle='--',
             label="Test Loss")

    ax1.legend()
    ax2.legend()
    plt.suptitle(title_text)
    plt.savefig(file_name)

配置全连接模型

[ ]:
#定义全连接层

def FullyConnected_layers(number_of_layers,
                         num_neurons_of_layer,
                         inputs):
    """Stack `number_of_layers` equally-wide ReLU Dense layers on `inputs`.

    Layers are named '0th-hidden', '1th-hidden', ... so they can be looked
    up by name later for visualization. Returns the output tensor of the
    last hidden layer.
    """
    hidden = inputs
    for layer_idx in range(number_of_layers):
        dense = layers.Dense(num_neurons_of_layer,
                             activation=tf.nn.relu,
                             name=f'{layer_idx}th-hidden')
        hidden = dense(hidden)
        # Optionally a linear (no-activation) Dense layer could be inserted
        # here and visualized as well.
    return hidden

#构建全连接网络
#Build, compile and train the fully-connected network
def FullyConnected_Model(number_of_layers, num_neurons_of_layer):
    """Build and train an MLP classifier on the circle data.

    Reads the module-level globals X_train, y_train, X_test, y_test,
    batch_size and epochs, which must be defined before the call.

    Returns
    -------
    (Fully_Model, history) : the trained keras.Model and its fit History.
    """
    inputs = keras.Input(shape=(2, ))
    x = FullyConnected_layers(number_of_layers=number_of_layers,
                              num_neurons_of_layer=num_neurons_of_layer,
                              inputs=inputs)
    # Softmax classification head over the two classes
    outputs = layers.Dense(2, activation='softmax', name='activation')(x)

    Fully_Model = keras.Model(inputs, outputs)
    # `learning_rate` replaces the deprecated `lr` keyword (the original
    # emitted "The `lr` argument is deprecated" at every run).
    Fully_Model.compile(loss=keras.losses.categorical_crossentropy,
                        optimizer=keras.optimizers.Adam(learning_rate=3e-4),
                        metrics=['accuracy'])
    history = Fully_Model.fit(X_train,
                              y_train,
                              batch_size=batch_size,
                              epochs=epochs,
                              verbose=2,
                              validation_data=(X_test, y_test))
    return Fully_Model, history

进行实验

设置训练参数

[ ]:
#Training hyper-parameters (read as globals inside FullyConnected_Model)

epochs=100 #number of training epochs
batch_size=32 #mini-batch size
number_of_layers=4 #number of hidden layers
num_neurons_of_layer=2 #neurons per hidden layer; 2 keeps every hidden representation 2-D and plottable
Fully_Model,history=FullyConnected_Model(number_of_layers,num_neurons_of_layer)
/usr/local/lib/python3.7/dist-packages/keras/optimizers/optimizer_v2/adam.py:110: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  super(Adam, self).__init__(name, **kwargs)
Epoch 1/100
16/16 - 1s - loss: 0.7272 - accuracy: 0.5100 - val_loss: 0.6891 - val_accuracy: 0.4980 - 753ms/epoch - 47ms/step
Epoch 2/100
16/16 - 0s - loss: 0.7086 - accuracy: 0.5200 - val_loss: 0.6747 - val_accuracy: 0.4960 - 41ms/epoch - 3ms/step
Epoch 3/100
16/16 - 0s - loss: 0.6930 - accuracy: 0.5260 - val_loss: 0.6618 - val_accuracy: 0.5120 - 45ms/epoch - 3ms/step
Epoch 4/100
16/16 - 0s - loss: 0.6792 - accuracy: 0.5320 - val_loss: 0.6506 - val_accuracy: 0.5240 - 52ms/epoch - 3ms/step
Epoch 5/100
16/16 - 0s - loss: 0.6672 - accuracy: 0.5340 - val_loss: 0.6413 - val_accuracy: 0.5220 - 43ms/epoch - 3ms/step
Epoch 6/100
16/16 - 0s - loss: 0.6567 - accuracy: 0.5380 - val_loss: 0.6335 - val_accuracy: 0.5280 - 41ms/epoch - 3ms/step
Epoch 7/100
16/16 - 0s - loss: 0.6479 - accuracy: 0.5400 - val_loss: 0.6269 - val_accuracy: 0.5360 - 49ms/epoch - 3ms/step
Epoch 8/100
16/16 - 0s - loss: 0.6400 - accuracy: 0.5480 - val_loss: 0.6211 - val_accuracy: 0.5420 - 44ms/epoch - 3ms/step
Epoch 9/100
16/16 - 0s - loss: 0.6333 - accuracy: 0.5540 - val_loss: 0.6165 - val_accuracy: 0.5480 - 46ms/epoch - 3ms/step
Epoch 10/100
16/16 - 0s - loss: 0.6278 - accuracy: 0.5560 - val_loss: 0.6123 - val_accuracy: 0.5500 - 41ms/epoch - 3ms/step
Epoch 11/100
16/16 - 0s - loss: 0.6225 - accuracy: 0.5660 - val_loss: 0.6089 - val_accuracy: 0.5520 - 50ms/epoch - 3ms/step
Epoch 12/100
16/16 - 0s - loss: 0.6184 - accuracy: 0.5720 - val_loss: 0.6056 - val_accuracy: 0.5560 - 48ms/epoch - 3ms/step
Epoch 13/100
16/16 - 0s - loss: 0.6141 - accuracy: 0.5840 - val_loss: 0.6031 - val_accuracy: 0.5660 - 42ms/epoch - 3ms/step
Epoch 14/100
16/16 - 0s - loss: 0.6107 - accuracy: 0.5860 - val_loss: 0.6007 - val_accuracy: 0.5660 - 44ms/epoch - 3ms/step
Epoch 15/100
16/16 - 0s - loss: 0.6075 - accuracy: 0.5900 - val_loss: 0.5988 - val_accuracy: 0.5720 - 51ms/epoch - 3ms/step
Epoch 16/100
16/16 - 0s - loss: 0.6050 - accuracy: 0.5960 - val_loss: 0.5969 - val_accuracy: 0.5860 - 56ms/epoch - 3ms/step
Epoch 17/100
16/16 - 0s - loss: 0.6024 - accuracy: 0.6180 - val_loss: 0.5952 - val_accuracy: 0.5920 - 43ms/epoch - 3ms/step
Epoch 18/100
16/16 - 0s - loss: 0.5999 - accuracy: 0.6200 - val_loss: 0.5937 - val_accuracy: 0.5960 - 45ms/epoch - 3ms/step
Epoch 19/100
16/16 - 0s - loss: 0.5977 - accuracy: 0.6220 - val_loss: 0.5922 - val_accuracy: 0.6060 - 43ms/epoch - 3ms/step
Epoch 20/100
16/16 - 0s - loss: 0.5956 - accuracy: 0.6340 - val_loss: 0.5908 - val_accuracy: 0.6180 - 45ms/epoch - 3ms/step
Epoch 21/100
16/16 - 0s - loss: 0.5935 - accuracy: 0.6340 - val_loss: 0.5894 - val_accuracy: 0.6240 - 51ms/epoch - 3ms/step
Epoch 22/100
16/16 - 0s - loss: 0.5918 - accuracy: 0.6400 - val_loss: 0.5880 - val_accuracy: 0.6340 - 48ms/epoch - 3ms/step
Epoch 23/100
16/16 - 0s - loss: 0.5898 - accuracy: 0.6480 - val_loss: 0.5865 - val_accuracy: 0.6320 - 59ms/epoch - 4ms/step
Epoch 24/100
16/16 - 0s - loss: 0.5880 - accuracy: 0.6520 - val_loss: 0.5851 - val_accuracy: 0.6400 - 45ms/epoch - 3ms/step
Epoch 25/100
16/16 - 0s - loss: 0.5863 - accuracy: 0.6580 - val_loss: 0.5838 - val_accuracy: 0.6440 - 52ms/epoch - 3ms/step
Epoch 26/100
16/16 - 0s - loss: 0.5845 - accuracy: 0.6640 - val_loss: 0.5823 - val_accuracy: 0.6500 - 46ms/epoch - 3ms/step
Epoch 27/100
16/16 - 0s - loss: 0.5828 - accuracy: 0.6700 - val_loss: 0.5808 - val_accuracy: 0.6480 - 51ms/epoch - 3ms/step
Epoch 28/100
16/16 - 0s - loss: 0.5810 - accuracy: 0.6760 - val_loss: 0.5793 - val_accuracy: 0.6600 - 53ms/epoch - 3ms/step
Epoch 29/100
16/16 - 0s - loss: 0.5793 - accuracy: 0.6800 - val_loss: 0.5778 - val_accuracy: 0.6680 - 41ms/epoch - 3ms/step
Epoch 30/100
16/16 - 0s - loss: 0.5775 - accuracy: 0.6800 - val_loss: 0.5762 - val_accuracy: 0.6720 - 51ms/epoch - 3ms/step
Epoch 31/100
16/16 - 0s - loss: 0.5758 - accuracy: 0.6820 - val_loss: 0.5746 - val_accuracy: 0.6740 - 52ms/epoch - 3ms/step
Epoch 32/100
16/16 - 0s - loss: 0.5740 - accuracy: 0.6860 - val_loss: 0.5730 - val_accuracy: 0.6780 - 42ms/epoch - 3ms/step
Epoch 33/100
16/16 - 0s - loss: 0.5723 - accuracy: 0.6900 - val_loss: 0.5714 - val_accuracy: 0.6820 - 46ms/epoch - 3ms/step
Epoch 34/100
16/16 - 0s - loss: 0.5705 - accuracy: 0.6900 - val_loss: 0.5696 - val_accuracy: 0.6840 - 51ms/epoch - 3ms/step
Epoch 35/100
16/16 - 0s - loss: 0.5687 - accuracy: 0.6900 - val_loss: 0.5679 - val_accuracy: 0.6860 - 46ms/epoch - 3ms/step
Epoch 36/100
16/16 - 0s - loss: 0.5669 - accuracy: 0.6940 - val_loss: 0.5661 - val_accuracy: 0.6860 - 44ms/epoch - 3ms/step
Epoch 37/100
16/16 - 0s - loss: 0.5650 - accuracy: 0.6940 - val_loss: 0.5643 - val_accuracy: 0.6880 - 53ms/epoch - 3ms/step
Epoch 38/100
16/16 - 0s - loss: 0.5631 - accuracy: 0.6960 - val_loss: 0.5626 - val_accuracy: 0.6920 - 45ms/epoch - 3ms/step
Epoch 39/100
16/16 - 0s - loss: 0.5613 - accuracy: 0.7000 - val_loss: 0.5607 - val_accuracy: 0.6980 - 47ms/epoch - 3ms/step
Epoch 40/100
16/16 - 0s - loss: 0.5593 - accuracy: 0.7020 - val_loss: 0.5589 - val_accuracy: 0.6980 - 46ms/epoch - 3ms/step
Epoch 41/100
16/16 - 0s - loss: 0.5575 - accuracy: 0.7040 - val_loss: 0.5569 - val_accuracy: 0.7000 - 52ms/epoch - 3ms/step
Epoch 42/100
16/16 - 0s - loss: 0.5554 - accuracy: 0.7060 - val_loss: 0.5551 - val_accuracy: 0.7020 - 43ms/epoch - 3ms/step
Epoch 43/100
16/16 - 0s - loss: 0.5535 - accuracy: 0.7100 - val_loss: 0.5532 - val_accuracy: 0.7040 - 51ms/epoch - 3ms/step
Epoch 44/100
16/16 - 0s - loss: 0.5515 - accuracy: 0.7180 - val_loss: 0.5513 - val_accuracy: 0.7040 - 52ms/epoch - 3ms/step
Epoch 45/100
16/16 - 0s - loss: 0.5495 - accuracy: 0.7180 - val_loss: 0.5494 - val_accuracy: 0.7080 - 50ms/epoch - 3ms/step
Epoch 46/100
16/16 - 0s - loss: 0.5475 - accuracy: 0.7200 - val_loss: 0.5475 - val_accuracy: 0.7160 - 50ms/epoch - 3ms/step
Epoch 47/100
16/16 - 0s - loss: 0.5453 - accuracy: 0.7260 - val_loss: 0.5455 - val_accuracy: 0.7200 - 49ms/epoch - 3ms/step
Epoch 48/100
16/16 - 0s - loss: 0.5433 - accuracy: 0.7340 - val_loss: 0.5435 - val_accuracy: 0.7240 - 48ms/epoch - 3ms/step
Epoch 49/100
16/16 - 0s - loss: 0.5412 - accuracy: 0.7400 - val_loss: 0.5414 - val_accuracy: 0.7280 - 46ms/epoch - 3ms/step
Epoch 50/100
16/16 - 0s - loss: 0.5390 - accuracy: 0.7420 - val_loss: 0.5393 - val_accuracy: 0.7300 - 44ms/epoch - 3ms/step
Epoch 51/100
16/16 - 0s - loss: 0.5368 - accuracy: 0.7440 - val_loss: 0.5372 - val_accuracy: 0.7300 - 43ms/epoch - 3ms/step
Epoch 52/100
16/16 - 0s - loss: 0.5347 - accuracy: 0.7480 - val_loss: 0.5352 - val_accuracy: 0.7320 - 47ms/epoch - 3ms/step
Epoch 53/100
16/16 - 0s - loss: 0.5324 - accuracy: 0.7480 - val_loss: 0.5329 - val_accuracy: 0.7380 - 42ms/epoch - 3ms/step
Epoch 54/100
16/16 - 0s - loss: 0.5302 - accuracy: 0.7520 - val_loss: 0.5308 - val_accuracy: 0.7400 - 43ms/epoch - 3ms/step
Epoch 55/100
16/16 - 0s - loss: 0.5280 - accuracy: 0.7560 - val_loss: 0.5286 - val_accuracy: 0.7400 - 52ms/epoch - 3ms/step
Epoch 56/100
16/16 - 0s - loss: 0.5257 - accuracy: 0.7600 - val_loss: 0.5265 - val_accuracy: 0.7400 - 44ms/epoch - 3ms/step
Epoch 57/100
16/16 - 0s - loss: 0.5234 - accuracy: 0.7620 - val_loss: 0.5244 - val_accuracy: 0.7480 - 41ms/epoch - 3ms/step
Epoch 58/100
16/16 - 0s - loss: 0.5210 - accuracy: 0.7640 - val_loss: 0.5221 - val_accuracy: 0.7520 - 40ms/epoch - 3ms/step
Epoch 59/100
16/16 - 0s - loss: 0.5188 - accuracy: 0.7640 - val_loss: 0.5199 - val_accuracy: 0.7520 - 77ms/epoch - 5ms/step
Epoch 60/100
16/16 - 0s - loss: 0.5164 - accuracy: 0.7640 - val_loss: 0.5177 - val_accuracy: 0.7520 - 115ms/epoch - 7ms/step
Epoch 61/100
16/16 - 0s - loss: 0.5140 - accuracy: 0.7640 - val_loss: 0.5155 - val_accuracy: 0.7520 - 48ms/epoch - 3ms/step
Epoch 62/100
16/16 - 0s - loss: 0.5117 - accuracy: 0.7680 - val_loss: 0.5131 - val_accuracy: 0.7540 - 44ms/epoch - 3ms/step
Epoch 63/100
16/16 - 0s - loss: 0.5091 - accuracy: 0.7720 - val_loss: 0.5110 - val_accuracy: 0.7560 - 54ms/epoch - 3ms/step
Epoch 64/100
16/16 - 0s - loss: 0.5067 - accuracy: 0.7780 - val_loss: 0.5087 - val_accuracy: 0.7600 - 41ms/epoch - 3ms/step
Epoch 65/100
16/16 - 0s - loss: 0.5042 - accuracy: 0.7760 - val_loss: 0.5065 - val_accuracy: 0.7600 - 60ms/epoch - 4ms/step
Epoch 66/100
16/16 - 0s - loss: 0.5016 - accuracy: 0.7800 - val_loss: 0.5043 - val_accuracy: 0.7660 - 96ms/epoch - 6ms/step
Epoch 67/100
16/16 - 0s - loss: 0.4989 - accuracy: 0.7820 - val_loss: 0.5021 - val_accuracy: 0.7660 - 62ms/epoch - 4ms/step
Epoch 68/100
16/16 - 0s - loss: 0.4963 - accuracy: 0.7860 - val_loss: 0.5000 - val_accuracy: 0.7680 - 61ms/epoch - 4ms/step
Epoch 69/100
16/16 - 0s - loss: 0.4938 - accuracy: 0.7920 - val_loss: 0.4978 - val_accuracy: 0.7700 - 112ms/epoch - 7ms/step
Epoch 70/100
16/16 - 0s - loss: 0.4912 - accuracy: 0.7920 - val_loss: 0.4957 - val_accuracy: 0.7700 - 138ms/epoch - 9ms/step
Epoch 71/100
16/16 - 0s - loss: 0.4887 - accuracy: 0.7920 - val_loss: 0.4937 - val_accuracy: 0.7700 - 174ms/epoch - 11ms/step
Epoch 72/100
16/16 - 0s - loss: 0.4862 - accuracy: 0.7940 - val_loss: 0.4917 - val_accuracy: 0.7720 - 69ms/epoch - 4ms/step
Epoch 73/100
16/16 - 0s - loss: 0.4837 - accuracy: 0.7960 - val_loss: 0.4897 - val_accuracy: 0.7720 - 82ms/epoch - 5ms/step
Epoch 74/100
16/16 - 0s - loss: 0.4812 - accuracy: 0.7980 - val_loss: 0.4877 - val_accuracy: 0.7740 - 102ms/epoch - 6ms/step
Epoch 75/100
16/16 - 0s - loss: 0.4787 - accuracy: 0.7980 - val_loss: 0.4857 - val_accuracy: 0.7760 - 86ms/epoch - 5ms/step
Epoch 76/100
16/16 - 0s - loss: 0.4762 - accuracy: 0.8000 - val_loss: 0.4836 - val_accuracy: 0.7760 - 65ms/epoch - 4ms/step
Epoch 77/100
16/16 - 0s - loss: 0.4736 - accuracy: 0.8020 - val_loss: 0.4816 - val_accuracy: 0.7780 - 107ms/epoch - 7ms/step
Epoch 78/100
16/16 - 0s - loss: 0.4712 - accuracy: 0.8040 - val_loss: 0.4796 - val_accuracy: 0.7800 - 101ms/epoch - 6ms/step
Epoch 79/100
16/16 - 0s - loss: 0.4686 - accuracy: 0.8080 - val_loss: 0.4777 - val_accuracy: 0.7800 - 83ms/epoch - 5ms/step
Epoch 80/100
16/16 - 0s - loss: 0.4663 - accuracy: 0.8080 - val_loss: 0.4758 - val_accuracy: 0.7820 - 103ms/epoch - 6ms/step
Epoch 81/100
16/16 - 0s - loss: 0.4637 - accuracy: 0.8080 - val_loss: 0.4739 - val_accuracy: 0.7840 - 117ms/epoch - 7ms/step
Epoch 82/100
16/16 - 0s - loss: 0.4611 - accuracy: 0.8080 - val_loss: 0.4722 - val_accuracy: 0.7840 - 68ms/epoch - 4ms/step
Epoch 83/100
16/16 - 0s - loss: 0.4587 - accuracy: 0.8080 - val_loss: 0.4705 - val_accuracy: 0.7840 - 95ms/epoch - 6ms/step
Epoch 84/100
16/16 - 0s - loss: 0.4561 - accuracy: 0.8080 - val_loss: 0.4689 - val_accuracy: 0.7860 - 71ms/epoch - 4ms/step
Epoch 85/100
16/16 - 0s - loss: 0.4537 - accuracy: 0.8120 - val_loss: 0.4672 - val_accuracy: 0.7860 - 71ms/epoch - 4ms/step
Epoch 86/100
16/16 - 0s - loss: 0.4512 - accuracy: 0.8120 - val_loss: 0.4657 - val_accuracy: 0.7860 - 140ms/epoch - 9ms/step
Epoch 87/100
16/16 - 0s - loss: 0.4486 - accuracy: 0.8120 - val_loss: 0.4642 - val_accuracy: 0.7860 - 62ms/epoch - 4ms/step
Epoch 88/100
16/16 - 0s - loss: 0.4464 - accuracy: 0.8160 - val_loss: 0.4626 - val_accuracy: 0.7860 - 62ms/epoch - 4ms/step
Epoch 89/100
16/16 - 0s - loss: 0.4438 - accuracy: 0.8180 - val_loss: 0.4612 - val_accuracy: 0.7860 - 66ms/epoch - 4ms/step
Epoch 90/100
16/16 - 0s - loss: 0.4416 - accuracy: 0.8220 - val_loss: 0.4598 - val_accuracy: 0.7880 - 65ms/epoch - 4ms/step
Epoch 91/100
16/16 - 0s - loss: 0.4393 - accuracy: 0.8220 - val_loss: 0.4585 - val_accuracy: 0.7880 - 93ms/epoch - 6ms/step
Epoch 92/100
16/16 - 0s - loss: 0.4371 - accuracy: 0.8240 - val_loss: 0.4571 - val_accuracy: 0.7880 - 104ms/epoch - 7ms/step
Epoch 93/100
16/16 - 0s - loss: 0.4348 - accuracy: 0.8240 - val_loss: 0.4558 - val_accuracy: 0.7880 - 113ms/epoch - 7ms/step
Epoch 94/100
16/16 - 0s - loss: 0.4326 - accuracy: 0.8240 - val_loss: 0.4545 - val_accuracy: 0.7880 - 73ms/epoch - 5ms/step
Epoch 95/100
16/16 - 0s - loss: 0.4304 - accuracy: 0.8280 - val_loss: 0.4532 - val_accuracy: 0.7920 - 78ms/epoch - 5ms/step
Epoch 96/100
16/16 - 0s - loss: 0.4283 - accuracy: 0.8340 - val_loss: 0.4519 - val_accuracy: 0.7920 - 124ms/epoch - 8ms/step
Epoch 97/100
16/16 - 0s - loss: 0.4260 - accuracy: 0.8360 - val_loss: 0.4506 - val_accuracy: 0.7940 - 123ms/epoch - 8ms/step
Epoch 98/100
16/16 - 0s - loss: 0.4239 - accuracy: 0.8360 - val_loss: 0.4495 - val_accuracy: 0.7940 - 63ms/epoch - 4ms/step
Epoch 99/100
16/16 - 0s - loss: 0.4218 - accuracy: 0.8380 - val_loss: 0.4482 - val_accuracy: 0.7940 - 65ms/epoch - 4ms/step
Epoch 100/100
16/16 - 0s - loss: 0.4198 - accuracy: 0.8380 - val_loss: 0.4471 - val_accuracy: 0.7920 - 128ms/epoch - 8ms/step

可视化结果

绘制训练曲线

[ ]:
#Plot the training curve (plain strings: the original f-prefixes had no placeholders)
plot_loss_accuracy(history,
                   'Training curve',
                   'Training curve.pdf')
../../_images/1stPart_Homework.1_MLP_17_0.png

输出模型结构信息

[ ]:
#Print the model architecture and parameter counts
Fully_Model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input_1 (InputLayer)        [(None, 2)]               0

 0th-hidden (Dense)          (None, 2)                 6

 1th-hidden (Dense)          (None, 2)                 6

 2th-hidden (Dense)          (None, 2)                 6

 3th-hidden (Dense)          (None, 2)                 6

 activation (Dense)          (None, 2)                 6

=================================================================
Total params: 30
Trainable params: 30
Non-trainable params: 0
_________________________________________________________________

对原始特征空间剖分的可视化

[ ]:
#Visualize how the trained network partitions the original feature space

# Predicted probability of class 1 on the standardized background grid p
prob = Fully_Model.predict(p)[:, 1]
fig, ax1 = plt.subplots(1, 1, figsize=(7, 4), subplot_kw={'aspect': 'equal'})
# Background: the grid colored by predicted probability; foreground: the samples
ax1.scatter(*p.T, c=prob, cmap=cm_bright)
mp = ax1.scatter(*X.T, c=y, cmap=cm_bright, edgecolors='white', s=20, linewidths=0.5)
plt.colorbar(mp, ax=[ax1])
# Plain strings: the original f-prefixes had no placeholders
plt.title('Outputs of MLP')
plt.savefig('空间剖分结果.png')
plt.savefig('空间剖分结果.pdf')
79/79 [==============================] - 0s 2ms/step
../../_images/1stPart_Homework.1_MLP_21_1.png

更换背景显示原始样本

[ ]:
#Show the raw samples again on a ggplot-style background

mpl.style.use('ggplot')

# Rebuild the orange/blue colormap (same construction as above)
top = cm.get_cmap('Oranges_r', 512)
bottom = cm.get_cmap('Blues', 512)
newcolors = np.vstack((top(np.linspace(0.55, 1, 512)),
                       bottom(np.linspace(0, 0.75, 512))))
cm_bright = ListedColormap(newcolors, name='OrangeBlue')

# NOTE: the original created an extra 9x3 figure here that was never drawn on
# (it appeared in the output as an empty "<Figure size 648x216 with 0 Axes>");
# it and a duplicate plt.axis('equal') call have been removed.
fig = plt.figure(figsize=(8, 6))
m3 = plt.scatter(*X.T, c=y, cmap=cm_bright, edgecolors='white', s=20, linewidths=0.5)
plt.title(f'Raw data ({n_samples} points)')
plt.axis('equal')
plt.colorbar()
# No extension on the first path: matplotlib appends the default format
# (.png), which is the exact file name main()/create_gif look for later.
plt.savefig(f'Raw data ({n_samples} points)')
plt.savefig(f'Raw data ({n_samples} points).pdf')
plt.show()
<Figure size 648x216 with 0 Axes>
../../_images/1stPart_Homework.1_MLP_23_1.png

特征变换过程可视化

[ ]:
#Visualize the feature transformation performed by each hidden layer

# Collect the hidden Dense layers by the names assigned in FullyConnected_layers
hidden_layers = [Fully_Model.get_layer(f'{i}th-hidden')
                 for i in range(number_of_layers)]

inp = Fully_Model.input
outputs = [layer.output for layer in hidden_layers]
print(outputs)
# One backend function per hidden layer mapping network input -> layer output
functors = [K.function([inp], [out]) for out in outputs]
MLP_outs = [func([X]) for func in functors]

# Scatter the samples in each layer's 2-D feature space.
# plt.scatter is used explicitly instead of the bare `scatter` the original
# relied on, which only existed because %pylab dumps pyplot into the namespace.
for idx, layer_out in enumerate(MLP_outs):
    feats = layer_out[0]  # K.function returns a one-element list per call
    fig = plt.figure(figsize=(8, 6))
    plt.scatter(feats[:, 0], feats[:, 1],
                c=y, cmap=cm_bright, edgecolors='white', s=30, linewidths=0.1)
    plt.axis('equal')
    plt.title(f'Outputs of {idx+1}th hidden layer')
    plt.colorbar()
    plt.savefig(f'Outputs of {idx+1}th hidden layer.png')
    plt.savefig(f'Outputs of {idx+1}th hidden layer.pdf')
    plt.show()

#生成动图
#Assemble an animated GIF from a list of image files
def create_gif(image_list, gif_name, duration=1):
    """Read each file in image_list and write them as frames of gif_name.

    duration is the per-frame display time passed through to imageio.
    """
    frames = [imageio.imread(image_name) for image_name in image_list]
    imageio.mimsave(gif_name, frames, 'GIF', duration=duration)

def main():
    """Build the frame list (raw data, then each hidden layer) and write the GIF."""
    frame_files = [f'Raw data ({n_samples} points).png']
    frame_files += [f'Outputs of {i+1}th hidden layer.png'
                    for i in range(len(MLP_outs))]
    create_gif(frame_files, '特征变换动图.gif', 0.8)

main()
[<KerasTensor: shape=(None, 2) dtype=float32 (created by layer '0th-hidden')>, <KerasTensor: shape=(None, 2) dtype=float32 (created by layer '1th-hidden')>, <KerasTensor: shape=(None, 2) dtype=float32 (created by layer '2th-hidden')>, <KerasTensor: shape=(None, 2) dtype=float32 (created by layer '3th-hidden')>]
../../_images/1stPart_Homework.1_MLP_25_1.png
../../_images/1stPart_Homework.1_MLP_25_2.png
../../_images/1stPart_Homework.1_MLP_25_3.png
../../_images/1stPart_Homework.1_MLP_25_4.png