KNN演算法屬於監督學習演算法,是一種用於分類的非常簡單的演算法。簡單的說,KNN演算法採用測量不同特徵值之間的距離方法進行分類。具體演算法如下:
1)計算已知類別資料集中的點與當前點之間的距離
2)按照距離遞增次序排序
3)選取與當前距離最小的k個點
4)確定前k個點所在類別的出現頻率
5)返回前k個點出現頻率最高的類別作為當前點的預測分類
這次的資料集來自《機器學習實戰》一書的約會網站配對這一案例。格式如下:
這四列依次為:每年獲得的飛行常客裡程數;玩視頻遊戲所耗時間百分比;每周消費的冰淇淋公升數;以及對該對象的評價。為了將最後一項評價也轉換成數字,我定義的規則為:didntLike為1;smallDoses為2;largeDoses為3;共900條訓練資料,另準備了100條相同格式的測試資料用於最後計算錯誤率。
將這三維資料集任取兩列作得散佈圖如下:
取不同的特徵,最後呈現的效果也不同。但從上面兩圖可以看出,實驗用的資料集呈現聚類現象,利於做分類。
package knn;/** * @author shenchao * * 封裝一條資料 * */public class Data implements Comparable<Data>{/** * 每年獲得的飛行常客裡程數 */private double mile;/** * 玩視頻遊戲所耗時間百分比 */private double time;/** * 每周消費的冰淇淋公升數 */private double icecream;/** * 1 代表不喜歡的人 * 2 代表魅力一般的人 * 3 代表極具魅力的人 */private int type;/** * 兩個資料距離 */private double distance;public double getMile() {return mile;}public void setMile(double mile) {this.mile = mile;}public double getTime() {return time;}public void setTime(double time) {this.time = time;}public double getIcecream() {return icecream;}public void setIcecream(double icecream) {this.icecream = icecream;}public int getType() {return type;}public void setType(int type) {this.type = type;}public double getDistance() {return distance;}public void setDistance(double distance) {this.distance = distance;}/* (non-Javadoc) * @see java.lang.Comparable#compareTo(java.lang.Object) * 這裡進行倒排序 */@Overridepublic int compareTo(Data o) {if (this.distance < o.getDistance()) {return -1;}else if (this.distance > o.getDistance()) {return 1;}return 0;}} 對資料的封裝,為進行之後的排序,實現了comparable介面,重寫compareto方法。
package knn;import java.io.BufferedReader;import java.io.IOException;import java.io.InputStreamReader;import java.util.ArrayList;import java.util.Collections;import java.util.HashMap;import java.util.List;import java.util.Map;public class KNN {private List<Data> dataset = null;public KNN(String fileName) throws IOException {dataset = initDataSet(fileName);}private List<Data> initDataSet(String fileName) throws IOException {List<Data> list = new ArrayList<Data>();BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(KNN.class.getClassLoader().getResourceAsStream(fileName)));String line = null;while ((line = bufferedReader.readLine()) != null) {Data data = new Data();String[] s = line.split("\t");data.setMile(Double.parseDouble(s[0]));data.setTime(Double.parseDouble(s[1]));data.setIcecream(Double.parseDouble(s[2]));if (s[3].equals("largeDoses")) {data.setType(3);} else if (s[3].equals("smallDoses")) {data.setType(2);} else {data.setType(1);}list.add(data);}return list;}/** * 演算法核心 * * @param data * @param dataset * @param k */public int knn(Data data, List<Data> dataset, int k) {for (Data data2 : dataset) {double distance = calDistance(data, data2);data2.setDistance(distance);}// 對距離進行排序,倒序Collections.sort(dataset);// 從前k個樣本中,找到出現頻率最高的類別int type1 = 0, type2 = 0, type3 = 0;for (int i = 0; i < k; i++) {Data d = dataset.get(i);if (d.getType() == 1) {++type1;continue;} else if (d.getType() == 2) {++type2;continue;} else {++type3;}}//System.out.println(type1 + "========" + type2 + "=========" + type3);if (type1 > type2) {if (type1 > type3) {return 1;}else {return 3;}}else {if (type2 > type3) {return 2;}else {return 3;}}}/** * 計算兩個樣本點之間的距離 * * @param data * @param data2 * @return */private double calDistance(Data data, Data data2) {double sum = Math.pow((data.getMile() - data2.getMile()), 2)+ Math.pow((data.getIcecream() - data2.getIcecream()), 2)+ Math.pow((data.getTime() - data2.getTime()), 2);return Math.sqrt(sum);}/** * 將資料集歸一化處理<br> * <br> * newValue = (oldValue - min) / (max - min) * * @param dataset2 * @return */private List<Data> autoNorm(List<Data> oldDataSet) {List<Data> newDataSet = new ArrayList<Data>();// find max and minMap<String, Double> map = findMaxAndMin(oldDataSet);for (Data data : oldDataSet) {data.setMile(calNewValue(data.getMile(),map.get("maxDistance"), map.get("minDistance")));data.setTime(calNewValue(data.getTime(), map.get("maxTime"),map.get("minTime")));data.setIcecream(calNewValue(data.getIcecream(),map.get("maxIcecream"), map.get("minIcecream")));newDataSet.add(data);}return newDataSet;}/** * @param oldValue * @param maxValue * @param minValue * @return newValue = (oldValue - min) / (max - min) */private double calNewValue(double oldValue, double maxValue, double minValue) {return (double)(oldValue - minValue) / (maxValue - minValue);}/** * find the max and the min * * @return */private Map<String, Double> findMaxAndMin(List<Data> oldDataSet) {Map<String, Double> map = new HashMap<String, Double>();double maxDistance = Integer.MIN_VALUE;double minDistance = Integer.MAX_VALUE;double maxTime = Double.MIN_VALUE;double minTime = Double.MAX_VALUE;double maxIcecream = Double.MIN_VALUE;double minIcecream = Double.MAX_VALUE;for (Data data : oldDataSet) {if (data.getMile() > maxDistance) {maxDistance = data.getMile();}if (data.getMile() < minDistance) {minDistance = data.getMile();}if (data.getTime() > maxTime) {maxTime = data.getTime();}if (data.getTime() < minTime) {minTime = data.getTime();}if (data.getIcecream() > maxIcecream) {maxIcecream = data.getIcecream();}if (data.getIcecream() < minIcecream) {minIcecream = data.getIcecream();}}map.put("maxDistance", maxDistance);map.put("minDistance", minDistance);map.put("maxTime", maxTime);map.put("minTime", minTime);map.put("maxIcecream", maxIcecream);map.put("minIcecream", minIcecream);return map;}/** * 將資料集以散佈圖呈現 */public void show() {new ScatterPlotChart().showChart(dataset);}/** * 取已有資料的10%作為測試資料,這裡我們選取100個樣本作為測試樣本,其餘作為訓練樣本 * @throws IOException */public void test() throws IOException {List<Data> testDataSet = initDataSet("test.txt");//歸一化資料List<Data> newTestDataSet = autoNorm(testDataSet);List<Data> newDataSet = autoNorm(dataset);int errorCount = 0;for (Data data : newTestDataSet) {int type = knn(data, newDataSet, 6);if (type != data.getType()) {++errorCount;}}System.out.println("錯誤率:" + (double)errorCount / testDataSet.size() * 100 + "%");}public static void main(String[] args) throws IOException {KNN knn = new KNN("datingTestSet.txt");knn.test();}} 兩個樣本點之間的距離,仍舊採用歐幾裡得計演算法計算,但這之前很重要的一步是,對資料集進行歸一,這也是歐幾裡得演算法不足之處。因為數字差值最大的屬性對計算結果影響很大,也就是說,每年擷取的飛行常客裡程數對於計算結果的影響將遠遠大於其他兩個特徵。因此這裡將資料集映射到[0,1]之間。
最後用測試集中發生錯誤的次數除以總的測試數得到錯誤率,程式啟動並執行結果為百分之八,還湊合吧。
演算法不足之處:KNN演算法是分類資料最簡單最有效演算法之一,但是該演算法必須儲存全部的資料集,如果訓練資料集很大,必須使用很大的儲存空間,此外,由於必須對資料集中的每個資料計算距離,實際使用時可能非常耗時。另一個缺陷是它無法給出任何資料的基礎結構資訊,因此我們也無法知曉平均執行個體樣本和典型樣本具有什麼特徵。
如有什麼問題,歡迎大家和我一起學習交流。