JavaScript機器學習之KNN算法

Leo09L 7年前發布 | 26K 次閱讀 KNN 算法 JavaScript開發 JavaScript

上圖使用 plot.ly 所畫。

上次我們用JavaScript實現了 線性規劃 ,這次我們來聊聊KNN算法。

KNN是 k-Nearest-Neighbours 的縮寫,它是一種監督學習算法。KNN算法可以用來做分類,也可以用來解決回歸問題。

GitHub倉庫: machine-learning-with-js

KNN算法簡介

簡單地說, KNN算法由那離自己最近的K個點來投票決定待分類數據歸為哪一類 。

如果待分類的數據有這些鄰近數據, NY : 7 , NJ : 0 , IN : 4 ,即它有7個 NY 鄰居,0個 NJ 鄰居,4個 IN 鄰居,則這個數據應該歸類為 NY

假設你在郵局工作,你的任務是為郵遞員分配信件,目標是最小化到各個社區的投遞旅程。不妨假設一共有7個街區。這就是一個實際的分類問題。你需要將這些信件分類,決定它屬于哪個社區,比如 上東城曼哈頓下城 等。

最壞的方案是隨意分配信件分配給郵遞員,這樣每個郵遞員會拿到各個社區的信件。

最佳的方案是根據信件地址進行分類,這樣每個郵遞員只需要負責鄰近社區的信件。

也許你是這樣想的:”將鄰近3個街區的信件分配給同一個郵遞員”。這時,鄰近街區的個數就是 k 。你可以不斷增加 k ,直到獲得最佳的分配方案。這個 k 就是分類問題的最佳值。

KNN代碼實現

上次 一樣,我們將使用 mljsKNN 模塊 ml-knn 來實現。

每一個機器學習算法都需要數據,這次我將使用 IRIS數據集 。其數據集包含了150個樣本,都屬于 鳶尾屬 下的三個亞屬,分別是 山鳶尾變色鳶尾維吉尼亞鳶尾 。四個特征被用作樣本的定量分析,它們分別是 花萼花瓣 的長度和寬度。

1. 安裝模塊

$npm install ml-knn@2.0.0 csvtojson prompt

ml-knn : k-Nearest-Neighbours 模塊,不同版本的接口可能不同,這篇博客使用了2.0.0

csvtojson : 用于將CSV數據轉換為JSON

prompt : 在控制臺輸入輸出數據

2. 初始化并導入數據

IRIS數據集 由加州大學歐文分校提供。

curl https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data > iris.csv

假設你已經初始化了一個NPM項目,請在 index.js 中輸入以下內容:

const KNN = require('ml-knn');
const csv = require('csvtojson');
const prompt = require('prompt');

var knn;

const csvFilePath = 'iris.csv'; // 數據集
const names = ['sepalLength', 'sepalWidth', 'petalLength', 'petalWidth', 'type'];

let seperationSize; // 分割訓練和測試數據

let data = [],
 X = [],
 y = [];

let trainingSetX = [],
 trainingSetY = [],
 testSetX = [],
 testSetY = [];
  • seperationSize 用于分割數據和測試數據

使用csvtojson模塊的fromFile方法加載數據:

csv(
 {
 noheader: true,
 headers: names
 })
 .fromFile(csvFilePath)
 .on('json', (jsonObj) =>
 {
 data.push(jsonObj); // 將數據集轉換為JS對象數組
 })
 .on('done', (error) =>
 {
 seperationSize = 0.7 * data.length;
 data = shuffleArray(data);
 dressData();
 });

我們將 seperationSize 設為樣本數目的0.7倍。注意,如果訓練數據集太小的話,分類效果將變差。

由于數據集是根據種類排序的,所以需要使用 shuffleArray 函數對數據進行混淆,這樣才能方便分割出訓練數據。這個函數的定義請參考StackOverflow的提問 How to randomize (shuffle) a JavaScript array? :

function shuffleArray(array)
{
 for (var i = array.length - 1; i > 0; i--)
 {
 var j = Math.floor(Math.random() * (i + 1));
 var temp = array[i];
 array[i] = array[j];
 array[j] = temp;
 }
 return array;
}

3. 轉換數據

數據集中每一條數據可以轉換為一個JS對象:

{
sepalLength: ‘5.1’,
sepalWidth: ‘3.5’,
petalLength: ‘1.4’,
petalWidth: ‘0.2’,
type: ‘Iris-setosa’ 
}

在使用 KNN 算法訓練數據之前,需要對數據進行這些處理:

  1. 將屬性(sepalLength, sepalWidth,petalLength,petalWidth)由字符串轉換為浮點數. ( parseFloat )
  2. 將分類 (type)用數字表示
function dressData()
{
 let types = new Set(); 
 data.forEach((row) =>
 {
 types.add(row.type);
 });
 let typesArray = [...types]; 

 data.forEach((row) =>
 {
 let rowArray, typeNumber;
 rowArray = Object.keys(row).map(key => parseFloat(row[key])).slice(0, 4);
 typeNumber = typesArray.indexOf(row.type); // Convert type(String) to type(Number)

 X.push(rowArray);
 y.push(typeNumber);
 });

 trainingSetX = X.slice(0, seperationSize);
 trainingSetY = y.slice(0, seperationSize);
 testSetX = X.slice(seperationSize);
 testSetY = y.slice(seperationSize);

 train();
}

4. 訓練數據并測試

function train()
{
 knn = new KNN(trainingSetX, trainingSetY,
 {
 k: 7
 });
 test();
}

train方法需要2個必須的參數: 輸入數據,即 花萼花瓣 的長度和寬度;實際分類,即 山鳶尾變色鳶尾維吉尼亞鳶尾 。另外,第三個參數是可選的,用于提供調整 KNN 算法的內部參數。我將 k 參數設為7,其默認值為5。

訓練好模型之后,就可以使用測試數據來檢查準確性了。我們主要對預測出錯的個數比較感興趣。

function test()
{
 const result = knn.predict(testSetX);
 const testSetLength = testSetX.length;
 const predictionError = error(result, testSetY);
 console.log(`Test Set Size = ${testSetLength} and number of Misclassifications = ${predictionError}`);
 predict();
}

比較預測值與真實值,就可以得到出錯個數:

function error(predicted, expected)
{
 let misclassifications = 0;
 for (var index = 0; index < predicted.length; index++)
 {
 if (predicted[index] !== expected[index])
 {
 misclassifications++;
 }
 }
 return misclassifications;
}

5. 進行預測(可選)

任意輸入屬性值,就可以得到預測值

function predict()
{
 let temp = [];
 prompt.start();
 prompt.get(['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'], function(err, result)
 {
 if (!err)
 {
 for (var key in result)
 {
 temp.push(parseFloat(result[key]));
 }
 console.log(`With ${temp} -- type = ${knn.predict(temp)}`);
 }
 });
}

6. 完整程序

完整的程序 index.js 是這樣的:

const KNN = require('ml-knn');
const csv = require('csvtojson');
const prompt = require('prompt');

var knn;

const csvFilePath = 'iris.csv'; // 數據集
const names = ['sepalLength', 'sepalWidth', 'petalLength', 'petalWidth', 'type'];

let seperationSize; // 分割訓練和測試數據

let data = [],
 X = [],
 y = [];

let trainingSetX = [],
 trainingSetY = [],
 testSetX = [],
 testSetY = [];


csv(
 {
 noheader: true,
 headers: names
 })
 .fromFile(csvFilePath)
 .on('json', (jsonObj) =>
 {
 data.push(jsonObj); // 將數據集轉換為JS對象數組
 })
 .on('done', (error) =>
 {
 seperationSize = 0.7 * data.length;
 data = shuffleArray(data);
 dressData();
 });

function dressData()
{
 let types = new Set(); 
 data.forEach((row) =>
 {
 types.add(row.type);
 });
 let typesArray = [...types]; 

 data.forEach((row) =>
 {
 let rowArray, typeNumber;
 rowArray = Object.keys(row).map(key => parseFloat(row[key])).slice(0, 4);
 typeNumber = typesArray.indexOf(row.type); // Convert type(String) to type(Number)

 X.push(rowArray);
 y.push(typeNumber);
 });

 trainingSetX = X.slice(0, seperationSize);
 trainingSetY = y.slice(0, seperationSize);
 testSetX = X.slice(seperationSize);
 testSetY = y.slice(seperationSize);

 train();
}


// 使用KNN算法訓練數據
function train()
{
 knn = new KNN(trainingSetX, trainingSetY,
 {
 k: 7
 });
 test();
}


// 測試訓練的模型
function test()
{
 const result = knn.predict(testSetX);
 const testSetLength = testSetX.length;
 const predictionError = error(result, testSetY);
 console.log(`Test Set Size = ${testSetLength} and number of Misclassifications = ${predictionError}`);
 predict();
}


// 計算出錯個數
function error(predicted, expected)
{
 let misclassifications = 0;
 for (var index = 0; index < predicted.length; index++)
 {
 if (predicted[index] !== expected[index])
 {
 misclassifications++;
 }
 }
 return misclassifications;
}


// 根據輸入預測結果
function predict()
{
 let temp = [];
 prompt.start();
 prompt.get(['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'], function(err, result)
 {
 if (!err)
 {
 for (var key in result)
 {
 temp.push(parseFloat(result[key]));
 }
 console.log(`With ${temp} -- type = ${knn.predict(temp)}`);
 }
 });
}


// 混淆數據集的順序
function shuffleArray(array)
{
 for (var i = array.length - 1; i > 0; i--)
 {
 var j = Math.floor(Math.random() * (i + 1));
 var temp = array[i];
 array[i] = array[j];
 array[j] = temp;
 }
 return array;
}

在控制臺執行 node index.js

$ node index.js

輸出如下:

Test Set Size = 45 and number of Misclassifications = 2
prompt: Sepal Length: 1.7
prompt: Sepal Width: 2.5
prompt: Petal Length: 0.5
prompt: Petal Width: 3.4
With 1.7,2.5,0.5,3.4 -- type = 2

參考鏈接

歡迎加入 我們Fundebug全棧BUG監控交流群: 622902485

來自:https://kiwenlau.com/2017/07/10/javascript-machine-learning-knn/

 

 本文由用戶 Leo09L 自行上傳分享,僅供網友學習交流。所有權歸原作者,若您的權利被侵害,請聯系管理員。
 轉載本站原創文章,請注明出處,并保留原始鏈接、圖片水印。
 本站是一個以用戶分享為主的開源技術平臺,歡迎各類分享!