K-means算法(Spark Demo)

ye34 9年前發布 | 10K 次閱讀 Java 算法

import java.util.Random
import spark.SparkContext
import spark.SparkContext.
import spark.examples.Vector.

object SparkKMeans { /**

 * line -> vector
 */

def parseVector (line: String) : Vector = { return new Vector (line.split (' ').map (_.toDouble) ) }

/**
 * 計算該節點的最近中心節點
 */

def closestCenter (p: Vector, centers: Array[Vector]) : Int = { var bestIndex = 0 var bestDist = p.squaredDist (centers (0) ) //差平方之和 for (i < - 1 until centers.length) { val dist = p.squaredDist (centers (i) ) if (dist < bestDist) { bestDist = dist bestIndex = i } } return bestIndex }

def main (args: Array[String]) { if (args.length < 3) { System.err.println ("Usage: SparkKMeans <master> <file> <dimensions> <k> <iters>") System.exit (1) } val sc = new SparkContext (args (0), "SparkKMeans") val lines = sc.textFile (args (1), args (5).toInt) val points = lines.map (parseVector (_) ).cache() //文本中每行為一個節點,再將每個節點轉換成Vector val dimensions = args (2).toInt //節點的維度 val k = args (3).toInt //聚類個數 val iterations = args (4).toInt //迭代次數

                                                     // 隨機初始化k個中心節點
                                                     val rand = new Random (42)
    var centers = new Array[Vector] (k)
    for (i < - 0 until k)
        centers (i) = Vector (dimensions, _ => 2 * rand.nextDouble - 1)
                      println ("Initial centers: " + centers.mkString (", ") )
                      val time1 = System.currentTimeMillis()
        for (i < - 1 to iterations) {
            println ("On iteration " + i)

            // Map each point to the index of its closest center and a (point, 1) pair
            // that we will use to compute an average later
            val mappedPoints = points.map { p => (closestCenter (p, centers), (p, 1) ) }

            val newCenters = mappedPoints.reduceByKey {
            case ( (sum1, count1), (sum2, count2) ) => (sum1 + sum2, count1 + count2) //(向量相加, 計數器相加)
                } .map {
            case (id, (sum, count) ) => (id, sum / count) //根據前面的聚類,重新計算中心節點的位置
                } .collect

            // 更新中心節點
            for ( (id, value) < - newCenters) {
                centers (id) = value
            }
        }
                   val time2 = System.currentTimeMillis()
                               println ("Final centers: " + centers.mkString (", ") + ", time: " + (time2 - time1) )
}

}</pre>

 本文由用戶 ye34 自行上傳分享,僅供網友學習交流。所有權歸原作者,若您的權利被侵害,請聯系管理員。
 轉載本站原創文章,請注明出處,并保留原始鏈接、圖片水印。
 本站是一個以用戶分享為主的開源技術平臺,歡迎各類分享!