一、強分類器訓練過程
算法原理如下(參考自VIOLA P, JONES M. Robust real time object detection[A] . 8th IEEE International Conference on Computer Vision[C] . Vancouver , 2001.)
- 給定樣本 (x1; y1) , . . . , (xn; yn) ; 其中yi = 0表示負樣本,yi =1表示正樣本;
- 初始化權重:負樣本權重W0i= 1/2m, 正樣本權重W1i = 1/ 2l,其中m為負樣本總數,l為正樣本總數;
- 對于t = 1, ... T(T為訓練次數):
- 權重歸一化,簡單說就是使本輪所有樣本的權重的和為1;
- 根據每一個特征訓練簡單分類器,僅使用一個特征;
- 從所有簡單分類器中選出一個分錯率最低的分類器,為弱分類器;
- 更新權重
- 最后組合T個弱分類器為強分類器
二、代碼實現及說明(python)
目的:訓練得到一個強分類器,該強分類器分錯率低于預設值,且該強分類器由若干個弱分類器(對應單個特征)組成,通過若干個分類器及其權重計算得到的值對樣本進行分類。
def adaBoostTrainDS(dataArr,classLabels,numIt=40): weakClassArr = [] #存放強分類器的所有弱分類器信息 m = shape(dataArr)[0] D = mat(ones((m,1))/m) #權重初始化 aggClassEst = mat(zeros((m,1))) for i in range(numIt): bestStump,error,classEst = buildStump(dataArr,classLabels,D)#根據訓練樣本、權重得到一個弱分類器 print "D:",D.T alpha = float(0.5*log((1.0-error)/max(error,1e-16)))#計算alpha值,該值與分錯率相關,分錯率越小,該值越大,弱分類器權重 #max(error,1e-16)用于確保錯誤為0時不會發生除0溢出 bestStump['alpha'] = alpha weakClassArr.append(bestStump) #存儲該弱分類 print "classEst: ",classEst.T expon = multiply(-1*alpha*mat(classLabels).T,classEst) D = multiply(D,exp(expon)) #重新計算樣本權重 D = D/D.sum() #歸一化 #計算當前強分類器的分錯率,達到預期要求即停止 aggClassEst += alpha*classEst print "aggClassEst: ",aggClassEst.T aggErrors = multiply(sign(aggClassEst) != mat(classLabels).T,ones((m,1))) #計算數據點哪個是錯誤 print 'aggErrors: ',sign(aggClassEst) != mat(classLabels).T print 'aggErrors: ',aggErrors errorRate = aggErrors.sum()/m #計算錯誤率 print "total error: ",errorRate if errorRate == 0.0: break return weakClassArr
三、運行結果
訓練樣本:
datMat = matrix([[ 1. , 2.1, 0.3],
[ 2. , 1.1, 0.4],
[ 1.3, 1. , 1.2],
[ 1. , 1. , 1.1],
[ 2. , 1. , 1.3],
[ 7. , 2. , 0.35]])
classLabels = [1.0, 1.0, 1.0, -1.0, -1.0, -1.0]
訓練得到的強分類器(強分類器分錯率:0%,單個弱分類器最小分錯率為33%,在上一篇已經測試過):
[{'dim': 0, 'ineq': 'gt', 'thresh': 1.6000000000000001, 'alpha': 0.34657359027997275},
{'dim': 1, 'ineq': 'lt', 'thresh': 1.0, 'alpha': 0.5493061443340549},
{'dim': 0, 'ineq': 'gt', 'thresh': 2.2000000000000002, 'alpha': 0.5493061443340549},
{'dim': 2, 'ineq': 'gt', 'thresh': 0.29999999999999999, 'alpha': 0.4777557225137181},
{'dim': 0, 'ineq': 'lt', 'thresh': 1.0, 'alpha': 0.49926441505556346}]
手動計算分類:
針對第一個樣本[ 1. , 2.1,
0.3],利用強分類器計算結果如下:
- 0.34657359027997275
- 0.5493061443340549
-
0.5493061443340549
+
0.4777557225137181
+
0.49926441505556346
= -0.468165741378801--->小于0,正樣本
針對第六個樣本[
7. , 2. , 0.35],利用強分類器計算結果如下:
+ 0.34657359027997275
- 0.5493061443340549
+
0.5493061443340549
+
0.4777557225137181
-
0.49926441505556346
= +0.3250648977381274--->大于0,負樣本
其它樣本的計算類似
結論:
強分類器分類,即通過若干個分類器的權重的正負號計算得出,而正負號是通過該若分類器的閾值判斷得到;
強分類器比弱分類器準確率高。
文章列表