文章出處

仍然是 動手學嘗試學習系列的筆記,原文見:多類邏輯回歸 — 從0開始 。 這篇的主要目的,是從一堆服飾圖片中,通過機器學習識別出每個服飾圖片對應的分類是什么(比如:一個看起來象短袖上衣的圖片,應該歸類到T-Shirt分類)

示例代碼如下,這篇的代碼略復雜,分成幾個步驟解讀:

 

一、下載數據,并顯示圖片及標簽

 1 from mxnet import gluon
 2 from mxnet import ndarray as nd
 3 import matplotlib.pyplot as plt
 4 import mxnet as mx
 5 from mxnet import autograd
 6 
 7 def transform(data, label):
 8     return data.astype('float32')/255, label.astype('float32')
 9 
10 #訓練數據集(需聯網下載,網速慢時,會很卡)
11 mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)
12 
13 #測試數據集(需聯網下載)
14 mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)
15 
16 # data, label = mnist_train[0]
17 # ('example shape: ', data.shape, 'label:', label)
18 
19 #顯示服飾圖片
20 def show_images(images):
21     n = images.shape[0]
22     _, figs = plt.subplots(1, n, figsize=(15, 15))
23     for i in range(n):
24         figs[i].imshow(images[i].reshape((28, 28)).asnumpy())
25         figs[i].axes.get_xaxis().set_visible(False)
26         figs[i].axes.get_yaxis().set_visible(False)
27     plt.show()
28 
29 #獲取圖片對應分類標簽文本
30 def get_text_labels(label):
31     text_labels = [
32         'T 恤', '長 褲', '套頭衫', '裙 子', '外 套',
33         '涼 鞋', '襯 衣', '運動鞋', '包 包', '短 靴'
34     ]
35     return [text_labels[int(i)] for i in label]
36 
37 #下面這些代碼,用于輔助大家理解示例圖片數據集內部結構
38 # tup1 = mnist_train[0:1] #取出訓練集的第1個樣本
39 # print(type(tup1)) #<class 'tuple'> 可以看出這是個元組類型
40 # print(len(tup1)) #2 有2個元素
41 # print(type(tup1[0])) #<class 'mxnet.ndarray.ndarray.NDArray'> 第1個元素是一個矩陣
42 # print(type(tup1[1])) #<class 'numpy.ndarray'> 第2個元素是numpy的矩陣
43 # print(tup1[0].shape) #(1, 28, 28, 1) 第1個元素是一個四維矩陣,用來存儲每張圖中的像素點對應的值,最后1維表示RGB通道,這里只取了1個通道
44 # print(tup1[1].shape) #(1,) 第2個元素用于表示圖片對應的文本分類的索引值
45 # print(tup1[0]) #打印第1個元素(即:四維矩陣的值),<NDArray 1x28x28x1 @cpu(0)> 結果太長,就不列在注釋里了
46 # print(tup1[1]) #[2.],打印第2個元素(即:該圖片對應的分類索引數值)
47 # print(get_text_labels(tup1[1])) #顯示分類索引值對應的文本['pullover']
48 
49 #取出訓練集中的圖片數據,以及圖片標簽索引值
50 data, label = mnist_train[0:10]
51 
52 #打印數據集的相關信息
53 print('example shape: ', data.shape, 'label:', label)
54 
55 #顯示圖片
56 show_images(data)
57 
58 #打印圖片分類標簽
59 print(get_text_labels(label))
View Code

首次運行時,可能會很久都沒有反應,讓人誤以為代碼有問題,其實背后在聯網下載數據,去睡會兒,等醒來的時候,估計就下載好了~_~,下載的數據會保存在~/.mxnet/datasets/fashion-mnist目錄(mac環境):

下載完成后,上面的代碼會將圖片數據解析并顯示出來,類似下面這樣:

 

二、讀取數據并初始化參數

 1 #批量讀取數據
 2 batch_size = 256
 3 #訓練集
 4 train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)
 5 #測試集
 6 test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)
 7 
 8 #每張圖片的像素用向量表示,就是28*28的長度,即:784
 9 num_inputs = 784
10 #要預測10張圖片,即:輸出結果長度為10的向量
11 num_outputs = 10
12 
13 #初始化權重W、偏置b參數矩陣
14 W = nd.random_normal(shape=(num_inputs, num_outputs))
15 b = nd.random_normal(shape=num_outputs)
16 
17 params = [W, b]
18 
19 #附加梯度,方便后面用梯度下降法計算
20 for param in params:
21     param.attach_grad()
View Code

這與之前的 機器學習筆記(1):線性回歸 很類似,不再重復解釋 

 

三、創建模型

 1 #歸一化函數
 2 def softmax(X):
 3     exp = nd.exp(X)
 4     partition = exp.sum(axis=1, keepdims=True)
 5     return exp / partition
 6 
 7 #計算模型(仍然是類似y=w.x+b的方程)
 8 def net(X):
 9     return softmax(nd.dot(X.reshape((-1, num_inputs)), W) + b)
10 
11 #損失函數(使用交叉熵函數)
12 def cross_entropy(yhat, y):
13     return - nd.pick(nd.log(yhat), y)
14 
15 #梯度下降法
16 def SGD(params, lr):
17     for param in params:
18         param[:] = param - lr * param.grad
View Code

其中softmax(歸一化)及交叉熵cross_entropy,詳情可參考上篇:歸一化(softmax)、信息熵、交叉熵

 

四、如何評估準確度

 1 #計算準確度
 2 def accuracy(output, label):
 3     return nd.mean(output.argmax(axis=1) == label).asscalar()
 4 
 5 def _get_batch(batch):
 6     if isinstance(batch, mx.io.DataBatch):
 7         data = batch.data[0]
 8         label = batch.label[0]
 9     else:
10         data, label = batch
11     return data, label
12 
13 #評估準確度
14 def evaluate_accuracy(data_iterator, net):
15     acc = 0.
16     if isinstance(data_iterator, mx.io.MXDataIter):
17         data_iterator.reset()
18     for i, batch in enumerate(data_iterator):
19         data, label = _get_batch(batch)
20         output = net(data)
21         acc += accuracy(output, label)
22     return acc / (i+1)
View Code

機器學習的效果如何,通常要有一個評價值,上面的函數就是用來估計算法和模型準確度的。

注: 這里面用到了二個新的函數mean,argmax 解釋一下

mean類似sql中的avg函數,就是求平均值,即把一個矩陣的所有元數加起來,然后除以元數個數

from mxnet import ndarray as nd
x = nd.array([1,2,3,4,5,6]);
print(x,x.mean(),(1+2+3+4+5+6)/6.0)

輸出如下:

[ 1.  2.  3.  4.  5.  6.]
<NDArray 6 @cpu(0)> 
[ 3.5]
<NDArray 1 @cpu(0)> 3.5

而argmax,是找出(指定軸向)最大值的索引下標

from mxnet import ndarray as nd
x = nd.array([1,4,7,3,6])
print(x.argmax(axis=0))

輸出為[ 2.],即:第3列數字7最大。再來個多維矩陣的

如上圖,多維矩陣時,如果指定axis=0,表示軸的方向是縱向(自上而下),顯然第1列中的最大值7在第2行(即:row_index是1),第2列的最大值9在第3行(即:row_index=2),類推第3列的最大值8在第1行(row_index=0),最終輸出的結果就是[1, 2, 0]

如果把axis指定為1,則軸的方向為橫向(自左向右),如下圖:

axis為1時,輸出的索引,為列下標(即:第幾列),顯然8在第2列,7在第0列,9在第1列。

現在我們來想一下:為啥argmax結合mean這二個函數,可以用來評估準確度?

答案:預測的結果也是一個矩陣,通常預測對了,該元素值為1,預測錯誤則為0。

如上圖,假如有3個指標,預測對了2個,第三行,一個都沒預測對,那么準確率為2/3,即0.6666左右

 

五、訓練

 1 #學習率
 2 learning_rate = .1
 3 
 4 #開始訓練
 5 for epoch in range(5):
 6     train_loss = 0.
 7     train_acc = 0.
 8     for data, label in train_data:
 9         with autograd.record():
10             output = net(data)
11             loss = cross_entropy(output, label)
12         loss.backward()
13         SGD(params, learning_rate / batch_size)
14         train_loss += nd.mean(loss).asscalar()
15         train_acc += accuracy(output, label)
16 
17     test_acc = evaluate_accuracy(test_data, net)
18     print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
19         epoch, train_loss / len(train_data), train_acc / len(train_data), test_acc))
View Code

訓練過程與之前的機器學習筆記(1):線性回歸 套路一樣,參看之前的即可。

 

六、顯示預測結果

1 #顯示結果    
2 data, label = mnist_test[0:10]
3 show_images(data)
4 print('true labels')
5 print(get_text_labels(label))
6 
7 predicted_labels = net(data).argmax(axis=1)
8 print('predicted labels')
9 print(get_text_labels(predicted_labels.asnumpy()))
View Code

運行結果,參考下圖:

可以看到損失函數的計算值在一直下降(即:計算在收斂),最終的結果中紅線部分為100%預測正確的,其它一些外形相似的分類:襯衣、T恤、套頭衫、外套 這些都是"有袖子類的上衣",并沒有完全預測正確,但整體方向還是對的(即:并沒有把"上衣"識別成"鞋子"或"包包"等明顯不靠譜的分類),最終的模型、算法及參數有待進一步提高。


文章列表


不含病毒。www.avast.com
arrow
arrow
    全站熱搜
    創作者介紹
    創作者 大師兄 的頭像
    大師兄

    IT工程師數位筆記本

    大師兄 發表在 痞客邦 留言(0) 人氣()