使用Python訓練SVM模型識別手寫體數字
支持向量機SVM(Support Vector Machine)是有監督的分類預測模型,本篇文章使用機器學習庫scikit-learn中的手寫數字數據集介紹使用Python對SVM模型進行訓練并對手寫數字進行識別的過程。
準備工作
手寫數字識別的原理是將數字的圖片分割為8X8的灰度值矩陣,將這64個灰度值作為每個數字的訓練集對模型進行訓練。手寫數字所對應的真實數字作為分類結果。在機器學習sklearn庫中已經包含了不同數字的8X8灰度值矩陣,因此我們首先導入sklearn庫自帶的datasets數據集。然后是交叉驗證庫,SVM分類算法庫,繪制圖表庫等。
#導入自帶數據集 from sklearn import datasets #導入交叉驗證庫 from sklearn import cross_validation #導入SVM分類算法庫 from sklearn import svm #導入圖表庫 import matplotlib.pyplot as plt #生成預測結果準確率的混淆矩陣 from sklearn import metrics
讀取并查看數字矩陣
從sklearn庫自帶的datasets數據集中讀取數字的8X8矩陣信息并賦值給digits。
#讀取自帶數據集并賦值給digits digits = datasets.load_digits()
查看其中的數字9可以發現,手寫的數字9以64個灰度值保存。從下面的8×8矩陣中很難看出這是數字9。
#查看數據集中數字9的矩陣 digits.data[9]
以灰度值的方式輸出手寫數字9的圖像,可以看出個大概輪廓。這就是經過切割并以灰度保存的手寫數字9。它所對應的64個灰度值就是模型的訓練集,而真實的數字9是目標分類。我們的模型所要做的就是在已知64個灰度值與每個數字對應關系的情況下,通過對模型進行訓練來對新的手寫數字對應的真實數字進行分類。
#繪制圖表查看數據集中數字9的圖像 plt.imshow(digits.images[9], cmap=plt.cm.gray_r, interpolation='nearest') plt.title('digits.target[9]') plt.show()
設置模型的特征X和預測目標Y
查看數據集中的分類目標,可以看到一共有10個分類,分布為0-9。我們將這個分類目標賦值給Y,作為模型的預測目標。
#數據集中的目標分類 digits.target
#將數據集中的目標賦給Y Y=digits.target
手寫數字的64個灰度值作為特征賦值給X,這里需要說明的是64個灰度值是以8×8矩陣的形式保持的,因此我們需要使用reshape函數重新調整矩陣的行列數。這里也就是將8×8的兩維數據轉換為64×1的一維數據。
#使用reshape函數對矩陣進行轉換,并賦值給X n_samples = len(digits.images) X = digits.images.reshape((n_samples, 64))
查看特征值X和預測目標Y的行數,共有1797行,也就是說數據集中共有1797個手寫數字的圖像,64列是經過我們轉化后的灰度值。
#查看X和Y的行數 X.shape,Y.shape
將數據分割為訓練集和測試集
將1797個手寫數字的灰度值采用隨機抽樣的方法分割為訓練集和測試集,其中訓練集為60%,測試集為40%。
#隨機抽取生成訓練集和測試集,其中訓練集的比例為60%,測試集40% X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, Y, test_size=0.4, random_state=0)
查看分割后的測試集數據,共有1078條數據。這些數據將用來訓練SVM模型。
#查看訓練集的行數 X_train.shape,y_train.shape
對SVM模型進行訓練
將訓練集數據X_train和y_train代入到SVM模型中,對模型進行訓練。下面是具體的代碼和結果。
#生成SVM分類模型 clf = svm.SVC(gamma=0.001)
#使用訓練集對svm分類模型進行訓練 clf.fit(X_train, y_train)
使用測試集測對模型進行測試
使用測試集數據X_test和y_test對訓練后的SVM模型進行檢驗,模型對手寫數字分類的準確率為99.3%。這是非常高的準確率。那么是否真的這么靠譜嗎?下面我們來單獨測試下。
#使用測試集衡量分類模型準確率 clf.score(X_test, y_test)
我們使用測試集的特征X,也就是每個手寫數字的64個灰度值代入到模型中,讓SVM模型進行分類。
#對測試集數據進行預測 predicted=clf.predict(X_test)
然后查看前20個手寫數字的分類結果,也就是手寫數字所對應的真實數字。下面是具體的分類結果。
#查看前20個測試集的預測結果 predicted[:20]
再查看訓練集中前20個分類結果,也就是真實數字的情況,并將之前的分類結果與測試集的真實結果進行對比。
#查看測試集中的真實結果 expected=y_test
以下是測試集中前20個真實數字的結果,與前面SVM模型的分類結果對比,前20個結果是一致的。
#查看測試集中前20個真實結果 expected[:20]
使用混淆矩陣來看下SVM模型對所有測試集數據的預測與真實結果的準確率情況,下面是一個10X10的矩陣,左上角第一行第一個數字60表示實際為0,SVM模型也預測為0的個數,第一行第二個數字表示實際為0,SVM模型預測為1的數字。第二行第二個數字73表示實際為1,SVM模型也預測為1的個數。
#生成準確率的混淆矩陣(Confusion matrix) metrics.confusion_matrix(expected, predicted)
從混淆矩陣中可以看到,大部分的數字SVM的分類和預測都是正確的,但也有個別的數字分類錯誤,例如真實的數字2,SVM模型有一次錯誤的分類為1,還有一次錯誤分類為7。
以下為scikit-learn官方的案例說明及代碼。
http://scikit-learn.org/stable/auto_examples/classification/plot_digits_classification.html
來自:http://bluewhale.cc/2016-09-06/python-svm-recognizing-hand-written-digits.html