二分Kmeans的java實現

來源:互聯網
上載者:User

標籤:kmeans   機器學習   java   bisecting   

剛剛研究了Kmeans。Kmeans是一種十分簡單的聚類演算法。但是他十分依賴於使用者最初給定的k值。它無法發現任意形狀和大小的簇,最適合於發現球狀簇。他的時間複雜度為O(tkn)。kmeans演算法有兩個核心點:計算距離的公式&判斷迭代停止的條件。一般距採用歐式距離等可以隨意。判斷迭代停止的條件可以有:

1) 每個簇的中心點不再變化則停止迭代

2)所有簇的點與這個簇的中心點的誤差平方和(SSE)的所有簇的總和不再變化

3)設定人為的迭代次數,觀察實驗效果。


當初始簇心選擇不好的時候聚類的效果會很差。所以後來又有一個人提出了二分k均值(bisectingkmeans),其核心思路是:將初始的一個簇一分為二計算出誤差平方和最大的那個簇,對他進行再一次的二分。直至切分的簇的個數為k個停止。 其實質就是不斷的對選中的簇做k=2的kmeans切分。

因為聚類的誤差平方和能夠衡量聚類效能,該值越小表示資料點月接近於它們的質心,聚類效果就越好。所以我們就需要對誤差平方和最大的簇進行再一次的劃分,因為誤差平方和越大,表示該簇聚類越不好,越有可能是多個簇被當成一個簇了,所以我們首先需要對這個簇進行劃分。


下面是代碼,kmeans的原始代碼來源於http://blog.csdn.net/cyxlzzs/article/details/7416491,我稍作了一些修改。


package org.algorithm;import java.util.ArrayList;import java.util.List;/** * 二分k均值,實際上是對一個集合做多次的k=2的kmeans劃分, 每次劃分後會對sse值較大的簇再進行二分。 最終使得或分出來的簇的個數為k個則停止 *  * 這裡利用之前別人寫好的一個kmeans的java實現作為基礎類。 *  * @author l0979365428 *  */public class BisectingKmeans {private int k;// 分成多少簇private List<float[]> dataSet;// 當前要被二分的簇private List<ClusterSet> cluster; // 簇/** * @param args */public static void main(String[] args) {// 初始化一個Kmean對象,將k置為10BisectingKmeans bkm = new BisectingKmeans(5);// 初始化實驗集ArrayList<float[]> dataSet = new ArrayList<float[]>();dataSet.add(new float[] { 1, 2 });dataSet.add(new float[] { 3, 3 });dataSet.add(new float[] { 3, 4 });dataSet.add(new float[] { 5, 6 });dataSet.add(new float[] { 8, 9 });dataSet.add(new float[] { 4, 5 });dataSet.add(new float[] { 6, 4 });dataSet.add(new float[] { 3, 9 });dataSet.add(new float[] { 5, 9 });dataSet.add(new float[] { 4, 2 });dataSet.add(new float[] { 1, 9 });dataSet.add(new float[] { 7, 8 });// 設定未經處理資料集bkm.setDataSet(dataSet);// 執行演算法bkm.execute();// 得到聚類結果// ArrayList<ArrayList<float[]>> cluster = bkm.getCluster();// 查看結果// for (int i = 0; i < cluster.size(); i++) {// bkm.printDataArray(cluster.get(i), "cluster[" + i + "]");// }}public BisectingKmeans(int k) {// 比2還小有啥要劃分的意義麼if (k < 2) {k = 2;}this.k = k;}/** * 設定需分組的未經處理資料集 *  * @param dataSet */public void setDataSet(ArrayList<float[]> dataSet) {this.dataSet = dataSet;}/** * 執行演算法 */public void execute() {long startTime = System.currentTimeMillis();System.out.println("BisectingKmeans begins");BisectingKmeans();long endTime = System.currentTimeMillis();System.out.println("BisectingKmeans running time="+ (endTime - startTime) + "ms");System.out.println("BisectingKmeans ends");System.out.println();}/** * 初始化 */private void init() {int dataSetLength = dataSet.size();if (k > dataSetLength) {k = dataSetLength;}}/** * 初始化簇集合 *  * @return 一個分為k簇的空資料的簇集合 */private ArrayList<ArrayList<float[]>> initCluster() {ArrayList<ArrayList<float[]>> cluster = new ArrayList<ArrayList<float[]>>();for (int i = 0; i < k; i++) {cluster.add(new ArrayList<float[]>());}return cluster;}/** * Kmeans演算法核心過程方法 */private void BisectingKmeans() {init();if (k < 2) {// 小於2 則原樣輸出資料集被認為是只分了一個簇ClusterSet cs = new ClusterSet();cs.setClu(dataSet);cluster.add(cs);}// 調用kmeans進行二分cluster = new ArrayList();while (cluster.size() < k) {List<ClusterSet> clu = kmeans(dataSet);for (ClusterSet cl : clu) {cluster.add(cl);}if (cluster.size() == k)break;else// 順序計算他們的誤差平方和{float maxerro=0f;int maxclustersetindex=0;int i=0;for (ClusterSet tt : cluster) {//計算誤差平方和並得出誤差平方和最大的簇float erroe = CommonUtil.countRule(tt.getClu(), tt.getCenter());tt.setErro(erroe);if(maxerro<erroe){maxerro=erroe;maxclustersetindex=i;}i++;}dataSet=cluster.get(maxclustersetindex).getClu();cluster.remove(maxclustersetindex);}}int i=0;for(ClusterSet sc:cluster){CommonUtil.printDataArray(sc.getClu(),"cluster"+i);i++;}}/** * 調用kmeans得到兩個簇。 *  * @param dataSet * @return */private List<ClusterSet> kmeans(List<float[]> dataSet) {Kmeans k = new Kmeans(2);// 設定未經處理資料集k.setDataSet(dataSet);// 執行演算法k.execute();// 得到聚類結果List<List<float[]>> clus = k.getCluster();List<ClusterSet> clusterset = new ArrayList<ClusterSet>();int i = 0;for (List<float[]> cl : clus) {ClusterSet cs = new ClusterSet();cs.setClu(cl);cs.setCenter(k.getCenter().get(i));clusterset.add(cs);i++;}return clusterset;}class ClusterSet {private float erro;private List<float[]> clu;private float[] center;public float getErro() {return erro;}public void setErro(float erro) {this.erro = erro;}public List<float[]> getClu() {return clu;}public void setClu(List<float[]> clu) {this.clu = clu;}public float[] getCenter() {return center;}public void setCenter(float[] center) {this.center = center;}}}

package org.algorithm;import java.util.List;/** * 把計算距離和誤差的公式抽離出來 * @author l0979365428 * */public class CommonUtil {/** * 計算兩個點之間的距離 *  * @param element *            點1 * @param center *            點2 * @return 距離 */public static  float distance(float[] element, float[] center) {float distance = 0.0f;float x = element[0] - center[0];float y = element[1] - center[1];float z = x * x + y * y;distance = (float) Math.sqrt(z);return distance;}/** * 求兩點誤差平方的方法 *  * @param element *            點1 * @param center *            點2 * @return 誤差平方 */public static  float errorSquare(float[] element, float[] center) {float x = element[0] - center[0];float y = element[1] - center[1];float errSquare = x * x + y * y;return errSquare;}/** * 計算誤差平方和準則函數方法 */public static  float countRule( List<float[]> cluster,float[] center) {float jcF = 0;for (int j = 0; j < cluster.size(); j++) {jcF += CommonUtil.errorSquare(cluster.get(j), center);}return  jcF;}/** * 列印資料,測試用 *  * @param dataArray *            資料集 * @param dataArrayName *            資料集名稱 */public static  void printDataArray(List<float[]> dataArray, String dataArrayName) {for (int i = 0; i < dataArray.size(); i++) {System.out.println("print:" + dataArrayName + "[" + i + "]={"+ dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");}System.out.println("===================================");}}

package org.algorithm;import java.util.ArrayList;import java.util.List;import java.util.Random;/** * K均值聚類演算法 */public class Kmeans {private int k;// 分成多少簇private int m;// 迭代次數private int dataSetLength;// 資料集元素個數,即資料集的長度private List<float[]> dataSet;// 資料集鏈表private List<float[]> center;// 中心鏈表private List<List<float[]>> cluster; // 簇private List<Float> jc;// 誤差平方和,k越接近dataSetLength,誤差越小private Random random;public static void main(String[] args) {// 初始化一個Kmean對象,將k置為10Kmeans k = new Kmeans(5);// 初始化實驗集ArrayList<float[]> dataSet = new ArrayList<float[]>();dataSet.add(new float[] { 1, 2 });dataSet.add(new float[] { 3, 3 });dataSet.add(new float[] { 3, 4 });dataSet.add(new float[] { 5, 6 });dataSet.add(new float[] { 8, 9 });dataSet.add(new float[] { 4, 5 });dataSet.add(new float[] { 6, 4 });dataSet.add(new float[] { 3, 9 });dataSet.add(new float[] { 5, 9 });dataSet.add(new float[] { 4, 2 });dataSet.add(new float[] { 1, 9 });dataSet.add(new float[] { 7, 8 });// 設定未經處理資料集k.setDataSet(dataSet);// 執行演算法k.execute();// 得到聚類結果List<List<float[]>> cluster = k.getCluster();// 查看結果for (int i = 0; i < cluster.size(); i++) {CommonUtil.printDataArray(cluster.get(i), "cluster[" + i + "]");}}/** * 設定需分組的未經處理資料集 *  * @param dataSet */public void setDataSet(List<float[]> dataSet) {this.dataSet = dataSet;}/** * 擷取結果分組 *  * @return 結果集 */public List<List<float[]>> getCluster() {return cluster;}/** * 建構函式,傳入需要分成的簇數量 *  * @param k *            簇數量,若k<=0時,設定為1,若k大於資料來源的長度時,置為資料來源的長度 */public Kmeans(int k) {if (k <= 0) {k = 1;}this.k = k;}/** * 初始化 */private void init() {m = 0;random = new Random();if (dataSet == null || dataSet.size() == 0) {initDataSet();}dataSetLength = dataSet.size();if (k > dataSetLength) {k = dataSetLength;}center = initCenters();cluster = initCluster();jc = new ArrayList<Float>();}/** * 如果調用者未初始化資料集,則採用自我裝載資料集 */private void initDataSet() {dataSet = new ArrayList<float[]>();// 其中{6,3}是一樣的,所以長度為15的資料集分成14簇和15簇的誤差都為0float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, { 2, 5 },{ 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 },{ 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 } };for (int i = 0; i < dataSetArray.length; i++) {dataSet.add(dataSetArray[i]);}}/** * 初始化中心資料鏈表,分成多少簇就有多少個中心點 *  * @return 中心點集 */private ArrayList<float[]> initCenters() {ArrayList<float[]> center = new ArrayList<float[]>();int[] randoms = new int[k];boolean flag;int temp = random.nextInt(dataSetLength);randoms[0] = temp;for (int i = 1; i < k; i++) {flag = true;while (flag) {temp = random.nextInt(dataSetLength);int j = 0;while (j < i) {if (temp == randoms[j]) {break;}j++;}if (j == i) {flag = false;}}randoms[i] = temp;}for (int i = 0; i < k; i++) {center.add(dataSet.get(randoms[i]));// 產生初始化中心鏈表}return center;}/** * 初始化簇集合 *  * @return 一個分為k簇的空資料的簇集合 */private List<List<float[]>> initCluster() {List<List<float[]>> cluster = new ArrayList();for (int i = 0; i < k; i++) {cluster.add(new ArrayList<float[]>());}return cluster;}/** * 擷取距離集合中最小距離的位置 *  * @param distance *            距離數組 * @return 最小距離在距離數組中的位置 */private int minDistance(float[] distance) {float minDistance = distance[0];int minLocation = 0;for (int i = 1; i < distance.length; i++) {if (distance[i] < minDistance) {minDistance = distance[i];minLocation = i;} else if (distance[i] == minDistance) // 如果相等,隨機返回一個位置{if (random.nextInt(10) < 5) {minLocation = i;}}}return minLocation;}/** * 核心,將當前元素放到最小距離中心相關的簇中 */private void clusterSet() {float[] distance = new float[k];for (int i = 0; i < dataSetLength; i++) {for (int j = 0; j < k; j++) {distance[j] = CommonUtil.distance(dataSet.get(i), center.get(j));}int minLocation = minDistance(distance);cluster.get(minLocation).add(dataSet.get(i));// 核心,將當前元素放到最小距離中心相關的簇中}}/** * 計算誤差平方和準則函數方法 */private void countRule() {float jcF = 0;for (int i = 0; i < cluster.size(); i++) {for (int j = 0; j < cluster.get(i).size(); j++) {jcF += CommonUtil.errorSquare(cluster.get(i).get(j), center.get(i));}}jc.add(jcF);}/** * 設定新的簇中心方法 */private void setNewCenter() {for (int i = 0; i < k; i++) {int n = cluster.get(i).size();if (n != 0) {float[] newCenter = { 0, 0 };for (int j = 0; j < n; j++) {newCenter[0] += cluster.get(i).get(j)[0];newCenter[1] += cluster.get(i).get(j)[1];}// 設定一個平均值newCenter[0] = newCenter[0] / n;newCenter[1] = newCenter[1] / n;center.set(i, newCenter);}}}public List<float[]> getCenter() {return center;}public void setCenter(List<float[]> center) {this.center = center;}/** * Kmeans演算法核心過程方法 */private void kmeans() {init();// 迴圈分組,直到誤差不變為止while (true) {clusterSet();countRule();if (m != 0) {if (jc.get(m) - jc.get(m - 1) == 0) {break;}}setNewCenter();m++;cluster.clear();cluster = initCluster();}}/** * 執行演算法 */public void execute() {long startTime = System.currentTimeMillis();System.out.println("kmeans begins");kmeans();long endTime = System.currentTimeMillis();System.out.println("kmeans running time=" + (endTime - startTime)+ "ms");System.out.println("kmeans ends");System.out.println();}}

分別執行兩種聚類演算法都使得k=5結果如下:

Kmeans:

print:cluster[0]={5.0,6.0}print:cluster[1]={4.0,5.0}print:cluster[2]={6.0,4.0}===================================print:cluster[0]={1.0,2.0}print:cluster[1]={3.0,3.0}print:cluster[2]={3.0,4.0}print:cluster[3]={4.0,2.0}===================================print:cluster[0]={7.0,8.0}===================================print:cluster[0]={8.0,9.0}===================================print:cluster[0]={3.0,9.0}print:cluster[1]={5.0,9.0}print:cluster[2]={1.0,9.0}===================================

BisectingKmeans:
print:cluster0[0]={8.0,9.0}print:cluster0[1]={7.0,8.0}===================================print:cluster1[0]={3.0,4.0}print:cluster1[1]={5.0,6.0}print:cluster1[2]={4.0,5.0}print:cluster1[3]={6.0,4.0}===================================print:cluster2[0]={1.0,2.0}print:cluster2[1]={3.0,3.0}print:cluster2[2]={4.0,2.0}===================================print:cluster3[0]={1.0,9.0}===================================print:cluster4[0]={3.0,9.0}print:cluster4[1]={5.0,9.0}===================================

如上有理解問題還請指正。



參考文獻:

http://blog.csdn.net/zouxy09/article/details/17590137

http://wenku.baidu.com/link?url=e6sXeX_txPMnNnYy8W28mP-HSD2Lk8cQGbW-4esipqu95r-P4Ke2QPeHLhfBtoie6agplav6VtVwxlyg-jf_5byHJ_Ce93ARqA6U9rn6XKK

《機器學習實戰》

著作權聲明:本文為博主原創文章,未經博主允許不得轉載。

二分Kmeans的java實現

聯繫我們

該頁面正文內容均來源於網絡整理,並不代表阿里雲官方的觀點,該頁面所提到的產品和服務也與阿里云無關,如果該頁面內容對您造成了困擾,歡迎寫郵件給我們,收到郵件我們將在5個工作日內處理。

如果您發現本社區中有涉嫌抄襲的內容,歡迎發送郵件至: info-contact@alibabacloud.com 進行舉報並提供相關證據,工作人員會在 5 個工作天內聯絡您,一經查實,本站將立刻刪除涉嫌侵權內容。

A Free Trial That Lets You Build Big!

Start building with 50+ products and up to 12 months usage for Elastic Compute Service

  • Sales Support

    1 on 1 presale consultation

  • After-Sales Support

    24/7 Technical Support 6 Free Tickets per Quarter Faster Response

  • Alibaba Cloud offers highly flexible support services tailored to meet your exact needs.