CART分類回歸樹算法

jopen 9年前發布 | 70K 次閱讀 算法

CART分類回歸樹算法

與上次文章中提到的ID3算法和C4.5算法類似,CART算法也是一種決策樹分類算法。CART分類回歸樹算法的本質也是對數據進行分類的,最終數據的表現形式也是以樹形的模式展現的,與ID3,C4.5算法不同的是,他的分類標準所采用的算法不同了。下面列出了其中的一些不同之處:

1、CART最后形成的樹是一個二叉樹,每個節點會分成2個節點,左孩子節點和右孩子節點,而在ID3和C4.5中是按照分類屬性的值類型進行劃分,于是這就要求CART算法在所選定的屬性中又要劃分出最佳的屬性劃分值,節點如果選定了劃分屬性名稱還要確定里面按照那個值做一個二元的劃分。

2、CART算法對于屬性的值采用的是基于Gini系數值的方式做比較,gini某個屬性的某次值的劃分的gini指數的值為:

,pk就是分別為正負實例的概率,gini系數越小說明分類純度越高,可以想象成與熵的定義一樣。因此在最后計算的時候我們只取其中值最小的做出劃分。最后做比較的時候用的是gini的增益做比較,要對分類號的數據做出一個帶權重的gini指數的計算。舉一個網上的一個例子:


比如體溫為恒溫時包含哺乳類5個、鳥類2個,則:

體溫為非恒溫時包含爬行類3個、魚類3個、兩棲類2個,則

所以如果按照“體溫為恒溫和非恒溫”進行劃分的話,我們得到GINI的增益(類比信息增益):

最好的劃分就是使得GINI_Gain最小的劃分。

通過比較每個屬性的最小的gini指數值,作為最后的結果。

3、CART算法在把數據進行分類之后,會對樹進行一個剪枝,常用的用前剪枝和后剪枝法,而常見的后剪枝發包括代價復雜度剪枝,悲觀誤差剪枝等等,我寫的此次算法采用的是代價復雜度剪枝法。代價復雜度剪枝的算法公式為:

α表示的是每個非葉子節點的誤差增益率,可以理解為誤差代價,最后選出誤差代價最小的一個節點進行剪枝。

里面變量的意思為:

是子樹中包含的葉子節點個數;

是節點t的誤差代價,如果該節點被剪枝;

r(t)是節點t的誤差率;

p(t)是節點t上的數據占所有數據的比例。

是子樹Tt的誤差代價,如果該節點不被剪枝。它等于子樹Tt上所有葉子節點的誤差代價之和。下面說說我對于這個公式的理解:其實這個公式的本質是對于剪枝前和剪枝后的樣本偏差率做一個差值比較,一個好的分類當然是分類后的樣本偏差率相較于沒分類(就是剪枝掉的時候)的偏差率小,所以這時的值就會大,如果分類前后基本變化不大,則意味著分類不起什么效果,α值的分子位置就小,所以誤差代價就小,可以被剪枝。但是一般分類后的偏差率會小于分類前的,因為偏差數在高層節點的時候肯定比子節點的多,子節點偏差數最多與父親節點一樣。

CART算法實現

首先是程序的備用數據,我是把他存在了一個文字中,通過程序進行逐行的讀取:

Rid Age Income Student CreditRating BuysComputer
1 Youth High No Fair No
2 Youth High No Excellent No
3 MiddleAged High No Fair Yes
4 Senior Medium No Fair Yes
5 Senior Low Yes Fair Yes
6 Senior Low Yes Excellent No
7 MiddleAged Low Yes Excellent Yes
8 Youth Medium No Fair No
9 Youth Low Yes Fair Yes
10 Senior Medium Yes Fair Yes
11 Youth Medium Yes Excellent Yes
12 MiddleAged Medium No Excellent Yes
13 MiddleAged High Yes Fair Yes
14 Senior Medium No Excellent No
下面是主程序,里面有具體的注釋:


package DataMing_CART;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Queue;

import javax.lang.model.element.NestingKind;
import javax.swing.text.DefaultEditorKit.CutAction;
import javax.swing.text.html.MinimalHTMLWriter;

/**
 * CART分類回歸樹算法工具類
 * 
 * @author lyq
 * 
 */
public class CARTTool {
    // 類標號的值類型
    private final String YES = "Yes";
    private final String NO = "No";

    // 所有屬性的類型總數,在這里就是data源數據的列數
    private int attrNum;
    private String filePath;
    // 初始源數據,用一個二維字符數組存放模仿表格數據
    private String[][] data;
    // 數據的屬性行的名字
    private String[] attrNames;
    // 每個屬性的值所有類型
    private HashMap<String, ArrayList<String>> attrValue;

    public CARTTool(String filePath) {
        this.filePath = filePath;
        attrValue = new HashMap<>();
    }

    /**
     * 從文件中讀取數據
     */
    public void readDataFile() {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split(" ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        data = new String[dataArray.size()][];
        dataArray.toArray(data);
        attrNum = data[0].length;
        attrNames = data[0];

        /*
         * for (int i = 0; i < data.length; i++) { for (int j = 0; j <
         * data[0].length; j++) { System.out.print(" " + data[i][j]); }
         * System.out.print("\n"); }
         */

    }

    /**
     * 首先初始化每種屬性的值的所有類型,用于后面的子類熵的計算時用
     */
    public void initAttrValue() {
        ArrayList<String> tempValues;

        // 按照列的方式,從左往右找
        for (int j = 1; j < attrNum; j++) {
            // 從一列中的上往下開始尋找值
            tempValues = new ArrayList<>();
            for (int i = 1; i < data.length; i++) {
                if (!tempValues.contains(data[i][j])) {
                    // 如果這個屬性的值沒有添加過,則添加
                    tempValues.add(data[i][j]);
                }
            }

            // 一列屬性的值已經遍歷完畢,復制到map屬性表中
            attrValue.put(data[0][j], tempValues);
        }

        /*
         * for (Map.Entry entry : attrValue.entrySet()) {
         * System.out.println("key:value " + entry.getKey() + ":" +
         * entry.getValue()); }
         */
    }

    /**
     * 計算機基尼指數
     * 
     * @param remainData
     *            剩余數據
     * @param attrName
     *            屬性名稱
     * @param value
     *            屬性值
     * @param beLongValue
     *            分類是否屬于此屬性值
     * @return
     */
    public double computeGini(String[][] remainData, String attrName,
            String value, boolean beLongValue) {
        // 實例總數
        int total = 0;
        // 正實例數
        int posNum = 0;
        // 負實例數
        int negNum = 0;
        // 基尼指數
        double gini = 0;

        // 還是按列從左往右遍歷屬性
        for (int j = 1; j < attrNames.length; j++) {
            // 找到了指定的屬性
            if (attrName.equals(attrNames[j])) {
                for (int i = 1; i < remainData.length; i++) {
                    // 統計正負實例按照屬于和不屬于值類型進行劃分
                    if ((beLongValue && remainData[i][j].equals(value))
                            || (!beLongValue && !remainData[i][j].equals(value))) {
                        if (remainData[i][attrNames.length - 1].equals(YES)) {
                            // 判斷此行數據是否為正實例
                            posNum++;
                        } else {
                            negNum++;
                        }
                    }
                }
            }
        }

        total = posNum + negNum;
        double posProbobly = (double) posNum / total;
        double negProbobly = (double) negNum / total;
        gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly;

        // 返回計算基尼指數
        return gini;
    }

    /**
     * 計算屬性劃分的最小基尼指數,返回最小的屬性值劃分和最小的基尼指數,保存在一個數組中
     * 
     * @param remainData
     *            剩余誰
     * @param attrName
     *            屬性名稱
     * @return
     */
    public String[] computeAttrGini(String[][] remainData, String attrName) {
        String[] str = new String[2];
        // 最終該屬性的劃分類型值
        String spiltValue = "";
        // 臨時變量
        int tempNum = 0;
        // 保存屬性的值劃分時的最小的基尼指數
        double minGini = Integer.MAX_VALUE;
        ArrayList<String> valueTypes = attrValue.get(attrName);
        // 屬于此屬性值的實例數
        HashMap<String, Integer> belongNum = new HashMap<>();

        for (String string : valueTypes) {
            // 重新計數的時候,數字歸0
            tempNum = 0;
            // 按列從左往右遍歷屬性
            for (int j = 1; j < attrNames.length; j++) {
                // 找到了指定的屬性
                if (attrName.equals(attrNames[j])) {
                    for (int i = 1; i < remainData.length; i++) {
                        // 統計正負實例按照屬于和不屬于值類型進行劃分
                        if (remainData[i][j].equals(string)) {
                            tempNum++;
                        }
                    }
                }
            }

            belongNum.put(string, tempNum);
        }

        double tempGini = 0;
        double posProbably = 1.0;
        double negProbably = 1.0;
        for (String string : valueTypes) {
            tempGini = 0;

            posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1);
            negProbably = 1 - posProbably;

            tempGini += posProbably
                    * computeGini(remainData, attrName, string, true);
            tempGini += negProbably
                    * computeGini(remainData, attrName, string, false);

            if (tempGini < minGini) {
                minGini = tempGini;
                spiltValue = string;
            }
        }

        str[0] = spiltValue;
        str[1] = minGini + "";

        return str;
    }

    public void buildDecisionTree(AttrNode node, String parentAttrValue,
            String[][] remainData, ArrayList<String> remainAttr,
            boolean beLongParentValue) {
        // 屬性劃分值
        String valueType = "";
        // 劃分屬性名稱
        String spiltAttrName = "";
        double minGini = Integer.MAX_VALUE;
        double tempGini = 0;
        // 基尼指數數組,保存了基尼指數和此基尼指數的劃分屬性值
        String[] giniArray;

        if (beLongParentValue) {
            node.setParentAttrValue(parentAttrValue);
        } else {
            node.setParentAttrValue("!" + parentAttrValue);
        }

        if (remainAttr.size() == 0) {
            if (remainData.length > 1) {
                ArrayList<String> indexArray = new ArrayList<>();
                for (int i = 1; i < remainData.length; i++) {
                    indexArray.add(remainData[i][0]);
                }
                node.setDataIndex(indexArray);
            }
            System.out.println("attr remain null");
            return;
        }

        for (String str : remainAttr) {
            giniArray = computeAttrGini(remainData, str);
            tempGini = Double.parseDouble(giniArray[1]);

            if (tempGini < minGini) {
                spiltAttrName = str;
                minGini = tempGini;
                valueType = giniArray[0];
            }
        }
        // 移除劃分屬性
        remainAttr.remove(spiltAttrName);
        node.setAttrName(spiltAttrName);

        // 孩子節點,分類回歸樹中,每次二元劃分,分出2個孩子節點
        AttrNode[] childNode = new AttrNode[2];
        String[][] rData;

        boolean[] bArray = new boolean[] { true, false };
        for (int i = 0; i < bArray.length; i++) {
            // 二元劃分屬于屬性值的劃分
            rData = removeData(remainData, spiltAttrName, valueType, bArray[i]);

            boolean sameClass = true;
            ArrayList<String> indexArray = new ArrayList<>();
            for (int k = 1; k < rData.length; k++) {
                indexArray.add(rData[k][0]);
                // 判斷是否為同一類的
                if (!rData[k][attrNames.length - 1]
                        .equals(rData[1][attrNames.length - 1])) {
                    // 只要有1個不相等,就不是同類型的
                    sameClass = false;
                    break;
                }
            }

            childNode[i] = new AttrNode();
            if (!sameClass) {
                // 創建新的對象屬性,對象的同個引用會出錯
                ArrayList<String> rAttr = new ArrayList<>();
                for (String str : remainAttr) {
                    rAttr.add(str);
                }
                buildDecisionTree(childNode[i], valueType, rData, rAttr,
                        bArray[i]);
            } else {
                String pAtr = (bArray[i] ? valueType : "!" + valueType);
                childNode[i].setParentAttrValue(pAtr);
                childNode[i].setDataIndex(indexArray);
            }
        }

        node.setChildAttrNode(childNode);
    }

    /**
     * 屬性劃分完畢,進行數據的移除
     * 
     * @param srcData
     *            源數據
     * @param attrName
     *            劃分的屬性名稱
     * @param valueType
     *            屬性的值類型
     * @parame beLongValue 分類是否屬于此值類型
     */
    private String[][] removeData(String[][] srcData, String attrName,
            String valueType, boolean beLongValue) {
        String[][] desDataArray;
        ArrayList<String[]> desData = new ArrayList<>();
        // 待刪除數據
        ArrayList<String[]> selectData = new ArrayList<>();
        selectData.add(attrNames);

        // 數組數據轉化到列表中,方便移除
        for (int i = 0; i < srcData.length; i++) {
            desData.add(srcData[i]);
        }

        // 還是從左往右一列列的查找
        for (int j = 1; j < attrNames.length; j++) {
            if (attrNames[j].equals(attrName)) {
                for (int i = 1; i < desData.size(); i++) {
                    if (desData.get(i)[j].equals(valueType)) {
                        // 如果匹配這個數據,則移除其他的數據
                        selectData.add(desData.get(i));
                    }
                }
            }
        }

        if (beLongValue) {
            desDataArray = new String[selectData.size()][];
            selectData.toArray(desDataArray);
        } else {
            // 屬性名稱行不移除
            selectData.remove(attrNames);
            // 如果是劃分不屬于此類型的數據時,進行移除
            desData.removeAll(selectData);
            desDataArray = new String[desData.size()][];
            desData.toArray(desDataArray);
        }

        return desDataArray;
    }

    public void startBuildingTree() {
        readDataFile();
        initAttrValue();

        ArrayList<String> remainAttr = new ArrayList<>();
        // 添加屬性,除了最后一個類標號屬性
        for (int i = 1; i < attrNames.length - 1; i++) {
            remainAttr.add(attrNames[i]);
        }

        AttrNode rootNode = new AttrNode();
        buildDecisionTree(rootNode, "", data, remainAttr, false);
        setIndexAndAlpah(rootNode, 0, false);
        System.out.println("剪枝前:");
        showDecisionTree(rootNode, 1);
        setIndexAndAlpah(rootNode, 0, true);
        System.out.println("\n剪枝后:");
        showDecisionTree(rootNode, 1);
    }

    /**
     * 顯示決策樹
     * 
     * @param node
     *            待顯示的節點
     * @param blankNum
     *            行空格符,用于顯示樹型結構
     */
    private void showDecisionTree(AttrNode node, int blankNum) {
        System.out.println();
        for (int i = 0; i < blankNum; i++) {
            System.out.print("    ");
        }
        System.out.print("--");
        // 顯示分類的屬性值
        if (node.getParentAttrValue() != null
                && node.getParentAttrValue().length() > 0) {
            System.out.print(node.getParentAttrValue());
        } else {
            System.out.print("--");
        }
        System.out.print("--");

        if (node.getDataIndex() != null && node.getDataIndex().size() > 0) {
            String i = node.getDataIndex().get(0);
            System.out.print("【" + node.getNodeIndex() + "】類別:"
                    + data[Integer.parseInt(i)][attrNames.length - 1]);
            System.out.print("[");
            for (String index : node.getDataIndex()) {
                System.out.print(index + ", ");
            }
            System.out.print("]");
        } else {
            // 遞歸顯示子節點
            System.out.print("【" + node.getNodeIndex() + ":"
                    + node.getAttrName() + "】");
            if (node.getChildAttrNode() != null) {
                for (AttrNode childNode : node.getChildAttrNode()) {
                    showDecisionTree(childNode, 2 * blankNum);
                }
            } else {
                System.out.print("【  Child Null】");
            }
        }
    }

    /**
     * 為節點設置序列號,并計算每個節點的誤差率,用于后面剪枝
     * 
     * @param node
     *            開始的時候傳入的是根節點
     * @param index
     *            開始的索引號,從1開始
     * @param ifCutNode
     *            是否需要剪枝
     */
    private void setIndexAndAlpah(AttrNode node, int index, boolean ifCutNode) {
        AttrNode tempNode;
        // 最小誤差代價節點,即將被剪枝的節點
        AttrNode minAlphaNode = null;
        double minAlpah = Integer.MAX_VALUE;
        Queue<AttrNode> nodeQueue = new LinkedList<AttrNode>();

        nodeQueue.add(node);
        while (nodeQueue.size() > 0) {
            index++;
            // 從隊列頭部獲取首個節點
            tempNode = nodeQueue.poll();
            tempNode.setNodeIndex(index);
            if (tempNode.getChildAttrNode() != null) {
                for (AttrNode childNode : tempNode.getChildAttrNode()) {
                    nodeQueue.add(childNode);
                }
                computeAlpha(tempNode);
                if (tempNode.getAlpha() < minAlpah) {
                    minAlphaNode = tempNode;
                    minAlpah = tempNode.getAlpha();
                } else if (tempNode.getAlpha() == minAlpah) {
                    // 如果誤差代價值一樣,比較包含的葉子節點個數,剪枝有多葉子節點數的節點
                    if (tempNode.getLeafNum() > minAlphaNode.getLeafNum()) {
                        minAlphaNode = tempNode;
                    }
                }
            }
        }

        if (ifCutNode) {
            // 進行樹的剪枝,讓其左右孩子節點為null
            minAlphaNode.setChildAttrNode(null);
        }
    }

    /**
     * 為非葉子節點計算誤差代價,這里的后剪枝法用的是CCP代價復雜度剪枝
     * 
     * @param node
     *            待計算的非葉子節點
     */
    private void computeAlpha(AttrNode node) {
        double rt = 0;
        double Rt = 0;
        double alpha = 0;
        // 當前節點的數據總數
        int sumNum = 0;
        // 最少的偏差數
        int minNum = 0;

        ArrayList<String> dataIndex;
        ArrayList<AttrNode> leafNodes = new ArrayList<>();

        addLeafNode(node, leafNodes);
        node.setLeafNum(leafNodes.size());
        for (AttrNode attrNode : leafNodes) {
            dataIndex = attrNode.getDataIndex();

            int num = 0;
            sumNum += dataIndex.size();
            for (String s : dataIndex) {
                // 統計分類數據中的正負實例數
                if (data[Integer.parseInt(s)][attrNames.length - 1].equals(YES)) {
                    num++;
                }
            }
            minNum += num;

            // 取小數量的值部分
            if (1.0 * num / dataIndex.size() > 0.5) {
                num = dataIndex.size() - num;
            }

            rt += (1.0 * num / (data.length - 1));
        }

        //同樣取出少偏差的那部分
        if (1.0 * minNum / sumNum > 0.5) {
            minNum = sumNum - minNum;
        }

        Rt = 1.0 * minNum / (data.length - 1);
        alpha = 1.0 * (Rt - rt) / (leafNodes.size() - 1);
        node.setAlpha(alpha);
    }

    /**
     * 篩選出節點所包含的葉子節點數
     * 
     * @param node
     *            待篩選節點
     * @param leafNode
     *            葉子節點列表容器
     */
    private void addLeafNode(AttrNode node, ArrayList<AttrNode> leafNode) {
        ArrayList<String> dataIndex;

        if (node.getChildAttrNode() != null) {
            for (AttrNode childNode : node.getChildAttrNode()) {
                dataIndex = childNode.getDataIndex();
                if (dataIndex != null && dataIndex.size() > 0) {
                    // 說明此節點為葉子節點
                    leafNode.add(childNode);
                } else {
                    // 如果還是非葉子節點則繼續遞歸調用
                    addLeafNode(childNode, leafNode);
                }
            }
        }
    }

}
AttrNode節點的設計和屬性:


/**
 * 回歸分類樹節點
 * 
 * @author lyq
 * 
 */
public class AttrNode {
    // 節點屬性名字
    private String attrName;
    // 節點索引標號
    private int nodeIndex;
    //包含的葉子節點數
    private int leafNum;
    // 節點誤差率
    private double alpha;
    // 父親分類屬性值
    private String parentAttrValue;
    // 孩子節點
    private AttrNode[] childAttrNode;
    // 數據記錄索引
    private ArrayList<String> dataIndex;
    .....
get,set方法自行補上。客戶端的場景調用:


package DataMing_CART;

public class Client {
    public static void main(String[] args){
        String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";

        CARTTool tool = new CARTTool(filePath);

        tool.startBuildingTree();
    }
}
數據文件路徑自行修改,否則會報錯(特殊情況懶得處理了.....)。最后程序的輸出結果,請自行從左往右看,從上往下,左邊的是父親節點,上面的是考前的子節點:


剪枝前:

    --!--【1:Age】
        --MiddleAged--【2】類別:Yes[3, 7, 12, 13, ]
        --!MiddleAged--【3:Student】
                --No--【4:Income】
                                --High--【6】類別:No[1, 2, ]
                                --!High--【7:CreditRating】
                                                                --Fair--【10】類別:Yes[4, 8, ]
                                                                --!Fair--【11】類別:No[14, ]
                --!No--【5:CreditRating】
                                --Fair--【8】類別:Yes[5, 9, 10, ]
                                --!Fair--【9:Income】
                                                                --Medium--【12】類別:Yes[11, ]
                                                                --!Medium--【13】類別:No[6, ]
剪枝后:

    --!--【1:Age】
        --MiddleAged--【2】類別:Yes[3, 7, 12, 13, ]
        --!MiddleAged--【3:Student】
                --No--【4:Income】【  Child Null】
                --!No--【5:CreditRating】
                                --Fair--【8】類別:Yes[5, 9, 10, ]
                                --!Fair--【9:Income】
                                                                --Medium--【12】類別:Yes[11, ]
                                                                --!Medium--【13】類別:No[6, ]


結果分析:

我在一開始的時候根據的是最后分類的數據是否為同一個類標識的,如果都為YES或者都為NO的,分類終止,一般情況下都說的通,但是如果最后屬性劃分完畢了,剩余的數據還有存在類標識不一樣的情況就會偏差,比如說這里的7號CredaRating節點,下面Fair分支中的[4,8]就不是同類的。所以在后面的剪枝算法就被剪枝了。因為后面的4和7號節點的誤差代價率為0,說明分類前后沒有類偏差變化,這也見證了后剪枝算法的威力所在了。

在coding遇到的困難和改進的地方:

1、先說說在編碼時遇到的困難,在對節點進行賦索引標號值的時候出了問題,因為是之前生成樹的時候采用了DFS的思想,如果編號時也采用此方法就不對了,于是就用到了把節點取出放入隊列這樣的遍歷方式,就是BFS的方式為節點標號。

2、程序的一個改進的地方在于算一個非葉子節點的時候需要計算他所包含的葉子節點數,采用了從當前節點開始從上往下遞歸計算,而且每個非葉子節點都計算一遍,顯然這樣做的效率是不高,后來想到了一種從葉子節點開始計算,從下往上直到根節點,對父親節點的非葉子節點列表做更新操作,就只要計算一次,這有點dp的思想在里面了,由于時間關系,沒有來得及實現。

3、第二個優化點就是后剪枝算法的多樣化,我這里采用的是CCP代價復雜度算法,大家可以試著實現其他的諸如悲觀誤差算法進行剪枝,看看能不能把程序中4和7號節點識別出來,并且剪枝掉。

最后聲明一下,用了里面的一點點的公式和例子,參考資料:http://blog.csdn.net/tianguokaka/article/details/9018933

 本文由用戶 jopen 自行上傳分享,僅供網友學習交流。所有權歸原作者,若您的權利被侵害,請聯系管理員。
 轉載本站原創文章,請注明出處,并保留原始鏈接、圖片水印。
 本站是一個以用戶分享為主的開源技術平臺,歡迎各類分享!