結合PageRank算法用Java實現文本相似度

jopen 10年前發布 | 41K 次閱讀 算法

目標

嘗試了一下把PageRank算法結合了文本相似度計算。直覺上是想把一個list里,和大家都比較靠攏的文本可能最后的PageRank值會比較大。因為如 果最后計算的PageRank值大,說明有比較多的文本和他的相似度值比較高,或者有更多的文本向他靠攏。這樣是不是就可以得到一些相對核心的文本,或者 相對代表性的文本?如果是要在整堆文本里切分一些關鍵的詞做token,那么每個token在每份文本里的權重就可以不一樣,那么是否就可以得到比較核心 的token,來給這些文本打標簽?當然,分詞切詞的時候都是要用工具過濾掉stopword的。

我也只是想嘗試一下這個想法,就簡單實現了整個過程。可能實現上還有問題。我的結果是最后大家的PageRank值都非常接近。如:

    5.626742067958352, 5.626742066427658, 5.626742070495978, 5.626742056768215, 5.626742079766638  

選5,10,20,50都差不多。都非常接近。主 要在設置PageRank定制迭代的那個DISTANCE值,值越接近0,迭代次數越多,經過很多次游走之后,文本之間的關系都很相近,各自的 pagerank值相差不大。如果調成0.5這樣的級別,可能迭代了4次左右就停下來了,互相之間差距會大一些。具體看自己的需求來控制這個距離參數了

 

代碼實現

文本之間的相似度計算用的是余弦距離,先哈希過。下面是計算兩個List<String>的余弦距離代碼:

    package dcd.academic.recommend;

import java.util.ArrayList;  
import java.util.HashMap;  
import java.util.Iterator;  
import java.util.Map;  

import dcd.academic.util.StdOutUtil;  

public class CosineDis {  

    public static double getSimilarity(ArrayList<String> doc1, ArrayList<String> doc2) {  
        if (doc1 != null && doc1.size() > 0 && doc2 != null && doc2.size() > 0) {  

            Map<Long, int[]> AlgorithmMap = new HashMap<Long, int[]>();  

            for (int i = 0; i < doc1.size(); i++) {  
                String d1 = doc1.get(i);  
                long sIndex = hashId(d1);  
                int[] fq = AlgorithmMap.get(sIndex);  
                if (fq != null) {  
                    fq[0]++;  
                } else {  
                    fq = new int[2];  
                    fq[0] = 1;  
                    fq[1] = 0;  
                    AlgorithmMap.put(sIndex, fq);  
                }  
            }  

            for (int i = 0; i < doc2.size(); i++) {  
                String d2 = doc2.get(i);  
                long sIndex = hashId(d2);  
                int[] fq = AlgorithmMap.get(sIndex);  
                if (fq != null) {  
                    fq[1]++;  
                } else {  
                    fq = new int[2];  
                    fq[0] = 0;  
                    fq[1] = 1;  
                    AlgorithmMap.put(sIndex, fq);  
                }  

            }  

            Iterator<Long> iterator = AlgorithmMap.keySet().iterator();  
            double sqdoc1 = 0;  
            double sqdoc2 = 0;  
            double denominator = 0;  
            while (iterator.hasNext()) {  
                int[] c = AlgorithmMap.get(iterator.next());  
                denominator += c[0] * c[1];  
                sqdoc1 += c[0] * c[0];  
                sqdoc2 += c[1] * c[1];  
            }  

            return denominator / Math.sqrt(sqdoc1 * sqdoc2);  
        } else {  
            return 0;  
        }  
    }  

    public static long hashId(String s) {  
        long seed = 131; // 31 131 1313 13131 131313 etc.. BKDRHash  
        long hash = 0;  
        for (int i = 0; i < s.length(); i++) {  
            hash = (hash * seed) + s.charAt(i);  
        }  
        return hash;  
    }  

    public static void main(String[] args) {  
        ArrayList<String> t1 = new ArrayList<String>();  
        ArrayList<String> t2 = new ArrayList<String>();  
        t1.add("sa");  
        t1.add("dfg");  
        t1.add("df");  

        t2.add("gfd");  
        t2.add("sa");  

        StdOutUtil.out(getSimilarity(t1, t2));  
    }  
}  </pre><a href="/misc/goto?guid=4959553561072186348" class="CopyToClipboard" title="copy"></a></div>

</div> </div>

利用上面這個類,根據文本之間的相似度,為每份文本計算得到一個向量(最后要歸一一下),用來初始化PageRank的起始矩陣。我用的數據是我 solr里的論文標題+摘要的文本,我是通過SolrjHelper這個類去取得了一個List<String>。你想替換的話把這部分換成 自己想測試的String列就可以了。下面是讀取數據,生成向量給PageRank類的代碼:

    package dcd.academic.recommend;

import java.io.IOException;  
import java.net.UnknownHostException;  
import java.util.ArrayList;  
import java.util.List;  

import dcd.academic.mongodb.MyMongoClient;  
import dcd.academic.solrj.SolrjHelper;  
import dcd.academic.util.StdOutUtil;  
import dcd.academic.util.StringUtil;  

import com.mongodb.BasicDBList;  
import com.mongodb.BasicDBObject;  
import com.mongodb.DBCollection;  
import com.mongodb.DBCursor;  
import com.mongodb.DBObject;  

public class BtwPublication {  

    public static final int NUM = 20;  

    public static void main(String[] args) throws IOException{  
        BtwPublication bp = new BtwPublication();  
        //bp.updatePublicationForComma();  
        PageRank pageRank = new PageRank(bp.getPagerankS("random"));  
        pageRank.doPagerank();  
    }  

    public double getDist(String pub1, String pub2) throws IOException {  
        if (pub1 != null && pub2 != null) {  
            ArrayList<String> doc1 = StringUtil.getTokens(pub1);  
            ArrayList<String> doc2 = StringUtil.getTokens(pub2);  
            return CosineDis.getSimilarity(doc1, doc2);  
        } else {  
            return 0;  
        }  
    }  

//  public List<Map<String, String>> getPubs(String name) {  
//        
//  }  

    public List<List<Double>> getPagerankS(String text) throws IOException {  
        SolrjHelper helper = new SolrjHelper(1);  
        List<String> pubs = helper.getPubsByTitle(text, 0, NUM);  
        List<List<Double>> s = new ArrayList<List<Double>>();  
        for (String pub : pubs) {  
            List<Double> tmp_row = new ArrayList<Double>();  
            double total = 0.0;  
            for (String other : pubs) {  
                if (!pub.equals(other)) {  
                    double tmp = getDist(pub, other);  
                    tmp_row.add(tmp);  
                    total += tmp;  
                } else {  
                    tmp_row.add(0.0);  
                }  
            }  
            s.add(getNormalizedRow(tmp_row, total));  
        }  
        return s;  
    }  

    public List<Double> getNormalizedRow(List<Double> row, double d) {  
        List<Double> res = new ArrayList<Double>();  
        for (int i = 0; i < row.size(); i ++) {  
            res.add(row.get(i) / d);  
        }  
        StdOutUtil.out(res.toString());  
        return res;  
    }  
}  </pre><a href="/misc/goto?guid=4959553561072186348" class="CopyToClipboard" title="copy"></a></div>

</div> </div> 最后這個是PageRank類,部分參數可以自己再設置:

    package dcd.academic.recommend;

import java.util.ArrayList;  
import java.util.List;  
import java.util.Random;  

import dcd.academic.util.StdOutUtil;  

public class PageRank {  
    private static final double ALPHA = 0.85;  
    private static final double DISTANCE = 0.0000001;  
    private static final double MUL = 10;  

    public static int SIZE;  
    public static List<List<Double>> s;  

    PageRank(List<List<Double>> s) {  
        this.SIZE = s.get(0).size();  
        this.s = s;  
    }  

    public static void doPagerank() {  
        List<Double> q = new ArrayList<Double>();  
        for (int i = 0; i < SIZE; i ++) {  
            q.add(new Random().nextDouble()*MUL);  
        }  
        System.out.println("初始的向量q為:");  
        printVec(q);  
        System.out.println("初始的矩陣G為:");  
        printMatrix(getG(ALPHA));  
        List<Double> pageRank = calPageRank(q, ALPHA);  
        System.out.println("PageRank為:");  
        printVec(pageRank);  
        System.out.println();  
    }  

    /** 
     * 打印輸出一個矩陣 
     *  
     * @param m 
     */  
    public static void printMatrix(List<List<Double>> m) {  
        for (int i = 0; i < m.size(); i++) {  
            for (int j = 0; j < m.get(i).size(); j++) {  
                System.out.print(m.get(i).get(j) + ", ");  
            }  
            System.out.println();  
        }  
    }  

    /** 
     * 打印輸出一個向量 
     *  
     * @param v 
     */  
    public static void printVec(List<Double> v) {  
        for (int i = 0; i < v.size(); i++) {  
            System.out.print(v.get(i) + ", ");  
        }  
        System.out.println();  
    }  

    /** 
     * 獲得一個初始的隨機向量q 
     *  
     * @param n 
     *            向量q的維數 
     * @return 一個隨機的向量q,每一維是0-5之間的隨機數 
     */  
    public static List<Double> getInitQ(int n) {  
        Random random = new Random();  
        List<Double> q = new ArrayList<Double>();  
        for (int i = 0; i < n; i++) {  
            q.add(new Double(5 * random.nextDouble()));  
        }  
        return q;  
    }  

    /** 
     * 計算兩個向量的距離 
     *  
     * @param q1 
     *            第一個向量 
     * @param q2 
     *            第二個向量 
     * @return 它們的距離 
     */  
    public static double calDistance(List<Double> q1, List<Double> q2) {  
        double sum = 0;  

        if (q1.size() != q2.size()) {  
            return -1;  
        }  

        for (int i = 0; i < q1.size(); i++) {  
            sum += Math.pow(q1.get(i).doubleValue() - q2.get(i).doubleValue(),  
                    2);  
        }  
        return Math.sqrt(sum);  
    }  

    /** 
     * 計算pagerank 
     *  
     * @param q1 
     *            初始向量 
     * @param a 
     *            alpha的值 
     * @return pagerank的結果 
     */  
    public static List<Double> calPageRank(List<Double> q1, double a) {  

        List<List<Double>> g = getG(a);  
        List<Double> q = null;  
        while (true) {  
            q = vectorMulMatrix(g, q1);  
            double dis = calDistance(q, q1);  
            System.out.println(dis);  
            if (dis <= DISTANCE) {  
                System.out.println("q1:");  
                printVec(q1);  
                System.out.println("q:");  
                printVec(q);  
                break;  
            }  
            q1 = q;  
        }  
        return q;  
    }  

    /** 
     * 計算獲得初始的G矩陣 
     *  
     * @param a 
     *            為alpha的值,0.85 
     * @return 初始矩陣G 
     */  
    public static List<List<Double>> getG(double a) {  
        List<List<Double>> aS = numberMulMatrix(s, a);  
        List<List<Double>> nU = numberMulMatrix(getU(), (1 - a) / SIZE);  
        List<List<Double>> g = addMatrix(aS, nU);  
        return g;  
    }  

    /** 
     * 計算一個矩陣乘以一個向量 
     *  
     * @param m 
     *            一個矩陣 
     * @param v 
     *            一個向量 
     * @return 返回一個新的向量 
     */  
    public static List<Double> vectorMulMatrix(List<List<Double>> m,  
            List<Double> v) {  
        if (m == null || v == null || m.size() <= 0  
                || m.get(0).size() != v.size()) {  
            return null;  
        }  

        List<Double> list = new ArrayList<Double>();  
        for (int i = 0; i < m.size(); i++) {  
            double sum = 0;  
            for (int j = 0; j < m.get(i).size(); j++) {  
                double temp = m.get(i).get(j).doubleValue()  
                        * v.get(j).doubleValue();  
                sum += temp;  
            }  
            list.add(sum);  
        }  

        return list;  
    }  

    /** 
     * 計算兩個矩陣的和 
     *  
     * @param list1 
     *            第一個矩陣 
     * @param list2 
     *            第二個矩陣 
     * @return 兩個矩陣的和 
     */  
    public static List<List<Double>> addMatrix(List<List<Double>> list1,  
            List<List<Double>> list2) {  
        List<List<Double>> list = new ArrayList<List<Double>>();  
        if (list1.size() != list2.size() || list1.size() <= 0  
                || list2.size() <= 0) {  
            return null;  
        }  
        for (int i = 0; i < list1.size(); i++) {  
            list.add(new ArrayList<Double>());  
            for (int j = 0; j < list1.get(i).size(); j++) {  
                double temp = list1.get(i).get(j).doubleValue()  
                        + list2.get(i).get(j).doubleValue();  
                list.get(i).add(new Double(temp));  
            }  
        }  
        return list;  
    }  

    /** 
     * 計算一個數乘以矩陣 
     *  
     * @param s 
     *            矩陣s 
     * @param a 
     *            double類型的數 
     * @return 一個新的矩陣 
     */  
    public static List<List<Double>> numberMulMatrix(List<List<Double>> s,  
            double a) {  
        List<List<Double>> list = new ArrayList<List<Double>>();  

        for (int i = 0; i < s.size(); i++) {  
            list.add(new ArrayList<Double>());  
            for (int j = 0; j < s.get(i).size(); j++) {  
                double temp = a * s.get(i).get(j).doubleValue();  
                list.get(i).add(new Double(temp));  
            }  
        }  
        return list;  
    }  

    /** 
     * 初始化U矩陣,全1 
     *  
     * @return U 
     */  
    public static List<List<Double>> getU() {  
        List<Double> row = new ArrayList<Double>();  
        for (int i = 0; i < SIZE; i ++) {  
            row.add(new Double(1));  
        }  

        List<List<Double>> s = new ArrayList<List<Double>>();  
        for (int j = 0; j < SIZE; j ++) {  
            s.add(row);  
        }  
        return s;  
    }  
}  </pre><a href="/misc/goto?guid=4959553561072186348" class="CopyToClipboard" title="copy"></a></div>

</div> </div>

下面是我一次實驗結果的數據,我設置了五分文本,這樣看起來比較短:

 
[0.0, 0.09968643574761415, 0.2601130421632277, 0.31094706119099713, 0.32925346089816093]  
[0.1315115598803241, 0.0, 0.23650307622882252, 0.2827229880685279, 0.34926237582232544]  
[0.13521235055030142, 0.09318868159350341, 0.0, 0.3996835314966943, 0.3719154363595009]  
[0.1389453620825689, 0.0957614822411479, 0.34357346750710194, 0.0, 0.4217196881691813]  
[0.14612484353723476, 0.11749453142051332, 0.31752920814285096, 0.4188514168994011, 0.0]  
初始的向量q為:  
8.007763265073303, 3.1232982446687387, 1.1722525763669134, 5.906625842576609, 9.019220483814852,   
初始的矩陣G為:  
0.030000000000000006, 0.11473347038547205, 0.2510960858387436, 0.2943050020123476, 0.30986544176343683,   
0.14178482589827548, 0.030000000000000006, 0.23102761479449913, 0.2703145398582487, 0.3268730194489766,   
0.1449304979677562, 0.10921037935447789, 0.030000000000000006, 0.36973100177219015, 0.3461281209055758,   
0.14810355777018358, 0.11139725990497573, 0.3220374473810367, 0.030000000000000006, 0.38846173494380415,   
0.15420611700664955, 0.12987035170743633, 0.29989982692142336, 0.38602370436449096, 0.030000000000000006,   
8.215210604296416  
2.1786836521210637  
0.6343362349619535  
0.19024536572818584  
0.05836227176176904  
0.018354791916908083  
0.0059297512567364945  
0.0019669982458251243  
6.679891158687752E-4  
2.312017647733628E-4  
8.117199104238135E-5  
2.8787511843006215E-5  
1.0279598478348542E-5  
3.6872987746593366E-6  
1.3264993458811192E-6  
4.780938295685138E-7  
1.7251588746973008E-7  
6.229666266632005E-8  
q1:  
5.62674207030434, 5.626742074589739, 5.626742063777632, 5.626742101012727, 5.626742037269133,   
q:  
5.626742067958352, 5.626742066427658, 5.626742070495978, 5.626742056768215, 5.626742079766638,   
PageRank為:  
5.626742067958352, 5.626742066427658, 5.626742070495978, 5.626742056768215, 5.626742079766638,   
</div> </div> 來自:http://blog.csdn.net/pelick/article/details/8847457

 本文由用戶 jopen 自行上傳分享,僅供網友學習交流。所有權歸原作者,若您的權利被侵害,請聯系管理員。
 轉載本站原創文章,請注明出處,并保留原始鏈接、圖片水印。
 本站是一個以用戶分享為主的開源技術平臺,歡迎各類分享!