使用Go 機器學習庫來進行數據分析 (決策樹)
決策樹和隨機森林
決策樹是機器學習中最接近人類思考問題的過程的一種算法,通過若干個節點,對特征進行提問并分類(可以是二分類也可以使多分類),直至最后生成葉節點(也就是只剩下一種屬性)。
每個決策樹都表述了一種樹型結構,它由它的分支來對該類型的對象依靠屬性進行分類。每個決策樹可以依靠對源數據庫的分割進行數據測試。這個過程可以遞歸式的對樹進行修剪。 當不能再進行分割或一個單獨的類可以被應用于某一分支時,遞歸過程就完成了。另外,隨機森林分類器將許多決策樹結合起來以提升分類的正確率。
golearn支持兩種決策樹算法。ID3和RandomTree。
-
ID3: 以信息增益為準則選擇信息增益最大的屬性。
ID3 is a decision tree induction algorithm which splits on the Attribute which gives the greatest Information Gain (entropy gradient). It performs well on categorical data. Numeric datasets will need to be discretised before using ID3
-
RandomTree: 與ID3類似,但是選擇的屬性的時候隨機選擇。
Random Trees are structurally identical to those generated by ID3, but the split Attribute is chosen randomly. Golearn's implementation allows you to choose up to k nodes for consideration at each split.
可以參考 ChongmingLiu的介紹: 決策樹(ID3 & C4.5 & CART) 。
維基百科中對隨機森林的介紹:
在機器學習中,隨機森林是一個包含多個決策樹的分類器,并且其輸出的類別是由個別樹輸出的類別的眾數而定。 Leo Breiman和Adele Cutler發展出推論出隨機森林的算法。而"Random Forests"是他們的商標。這個術語是1995年由貝爾實驗室的Tin Kam Ho所提出的隨機決策森林(random decision forests)而來的。這個方法則是結合Breimans的"Bootstrap aggregating"想法和Ho的"random subspace method" 以建造決策樹的集合。
在機器學習中,隨機森林由許多的決策樹組成,因為這些決策樹的形成采用了隨機的方法,因此也叫做隨機決策樹。隨機森林中的樹之間是沒有關聯的。當測試數據進入隨機森林時,其實就是讓每一顆決策樹進行分類,最后取所有決策樹中分類結果最多的那類為最終的結果。因此隨機森林是一個包含多個決策樹的分類器,并且其輸出的類別是由個別樹輸出的類別的眾數而定。
代碼
下面是使用決策樹和隨機森林預測鳶尾花分類的代碼,來自golearn:
// Demonstrates decision tree classification
package main
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/ensemble"
"github.com/sjwhitworth/golearn/evaluation"
"github.com/sjwhitworth/golearn/filters"
"github.com/sjwhitworth/golearn/trees"
"math/rand"
)
func main() {
var tree base.Classifier
rand.Seed(44111342)
// Load in the iris dataset
iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
// Discretise the iris dataset with Chi-Merge
filt := filters.NewChiMergeFilter(iris,0.999)
for _, a := range base.NonClassFloatAttributes(iris) {
filt.AddAttribute(a)
}
filt.Train()
irisf := base.NewLazilyFilteredInstances(iris, filt)
// Create a 60-40 training-test split
trainData, testData := base.InstancesTrainTestSplit(irisf,0.60)
//
// First up, use ID3
//
tree = trees.NewID3DecisionTree(0.6)
// (Parameter controls train-prune split.)
// Train the ID3 tree
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
// Generate predictions
predictions, err := tree.Predict(testData)
if err != nil {
panic(err)
}
// Evaluate
fmt.Println("ID3 Performance (information gain)")
cf, err := evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
}
fmt.Println(evaluation.GetSummary(cf))
tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.InformationGainRatioRuleGenerator))
// (Parameter controls train-prune split.)
// Train the ID3 tree
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
// Generate predictions
predictions, err = tree.Predict(testData)
if err != nil {
panic(err)
}
// Evaluate
fmt.Println("ID3 Performance (information gain ratio)")
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
}
fmt.Println(evaluation.GetSummary(cf))
tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.GiniCoefficientRuleGenerator))
// (Parameter controls train-prune split.)
// Train the ID3 tree
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
// Generate predictions
predictions, err = tree.Predict(testData)
if err != nil {
panic(err)
}
// Evaluate
fmt.Println("ID3 Performance (gini index generator)")
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
}
fmt.Println(evaluation.GetSummary(cf))
//
// Next up, Random Trees
//
// Consider two randomly-chosen attributes
tree = trees.NewRandomTree(2)
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
predictions, err = tree.Predict(testData)
if err != nil {
panic(err)
}
fmt.Println("RandomTree Performance")
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
}
fmt.Println(evaluation.GetSummary(cf))
//
// Finally, Random Forests
//
tree = ensemble.NewRandomForest(70,3)
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
predictions, err = tree.Predict(testData)
if err != nil {
panic(err)
}
fmt.Println("RandomForest Performance")
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
}
fmt.Println(evaluation.GetSummary(cf))
}
</code></pre>
首選使用 ChiMerge 方法進行數據離散,ChiMerge 是監督的、自底向上的(即基于合并的)數據離散化方法。它依賴于卡方分析:具有最小卡方值的相鄰區間合并在一起,直到滿足確定的停止準則。
基本思想:對于精確的離散化,相對類頻率在一個區間內應當完全一致。因此,如果兩個相鄰的區間具有非常類似的類分布,則這兩個區間可以合并;否則,它們應當保持分開。而低卡方值表明它們具有相似的類分布。可以參考"Principles of Data Mining"第二版中的第105頁--第115頁。
接著調用 base.NewLazilyFilteredInstances 應用filter得到FixedDataGrid。
之后將數據集分成訓練數據和測試數據兩部分。
接下來就是訓練數據、預測與評估了。
分別使用 ID3 、 ID3 with InformationGainRatioRuleGenerator 、 ID3 with GiniCoefficientRuleGenerator 、 RandomTree 、 RandomForest 算法進行處理。
評估結果
以下是各種算法的評估結果,可以和 kNN進行比較,看起來比不過kNN的預測。
ID3 Performance (information gain)
Reference Class True Positives False Positives True Negatives Precision Recall F1 Score
Iris-virginica 32 5 46 0.8649 0.9697 0.9143
Iris-versicolor 4 1 61 0.8000 0.1818 0.2963
Iris-setosa 29 13 42 0.6905 1.0000 0.8169
Overall accuracy: 0.7738
ID3 Performance (information gain ratio)
Reference Class True Positives False Positives True Negatives Precision Recall F1 Score
Iris-virginica 29 3 48 0.9062 0.8788 0.8923
Iris-versicolor 5 3 59 0.6250 0.2273 0.3333
Iris-setosa 29 15 40 0.6591 1.0000 0.7945
Overall accuracy: 0.7500
ID3 Performance (gini index generator)
Reference Class True Positives False Positives True Negatives Precision Recall F1 Score
Iris-virginica 26 5 46 0.8387 0.7879 0.8125
Iris-versicolor 17 36 26 0.3208 0.7727 0.4533
Iris-setosa 0 0 55 NaN 0.0000 NaN
Overall accuracy: 0.5119
RandomTree Performance
Reference Class True Positives False Positives True Negatives Precision Recall F1 Score
Iris-virginica 30 3 48 0.9091 0.9091 0.9091
Iris-versicolor 9 3 59 0.7500 0.4091 0.5294
Iris-setosa 29 10 45 0.7436 1.0000 0.8529
Overall accuracy: 0.8095
RandomForest Performance
Reference Class True Positives False Positives True Negatives Precision Recall F1 Score
Iris-virginica 31 8 43 0.7949 0.9394 0.8611
Iris-versicolor 0 0 62 NaN 0.0000 NaN
Iris-setosa 29 16 39 0.6444 1.0000 0.7838
Overall accuracy: 0.7143
</code></pre>
來自:http://colobu.com/2017/12/07/Two-Machine-Learning-for-Go/