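I came across a Java implementation of simple (one-variable) linear regression written by an expert online, and I found it quite useful. Companies doing data mining often need a feature that forecasts operating revenue, and simple linear regression can produce that kind of forecast. Here is the code: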
1. Define a DataPoint class that wraps the X and Y coordinates of a point:
/**
 * File        : DataPoint.java
 * Author      : zhouyujie
 * Date        : 2012-01-11 16:00:00
 * Description : Simple linear regression in Java; entity class for a
 *               coordinate point (usable for forecasting statistical metrics)
 */
package com.zyujie.dm;

public class DataPoint {
    /** the x value */
    public float x;
    /** the y value */
    public float y;

    /**
     * Constructor.
     *
     * @param x the x value
     * @param y the y value
     */
    public DataPoint(float x, float y) {
        this.x = x;
        this.y = y;
    }
}
2. Next, the class that implements the regression line:
/**
 * File        : RegressionLine.java
 * Author      : zhouyujie
 * Date        : 2012-01-11 16:00:00
 * Description : Simple linear regression in Java; regression line
 *               implementation (usable for forecasting statistical metrics)
 */
package com.zyujie.dm;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;

public class RegressionLine {
    /** sum of x */
    private double sumX;
    /** sum of y */
    private double sumY;
    /** sum of x*x */
    private double sumXX;
    /** sum of x*y */
    private double sumXY;
    /** sum of y*y */
    private double sumYY;
    /** every point added so far, kept so getR() can compute the residuals */
    private final List<DataPoint> points = new ArrayList<DataPoint>();
    /** integer bounds of the data seen so far (meaningful once pn > 0) */
    private int XMin, XMax, YMin, YMax;
    /** line coefficient a0 (intercept) */
    private float a0;
    /** line coefficient a1 (slope) */
    private float a1;
    /** number of data points */
    private int pn;
    /** true if coefficients valid */
    private boolean coefsValid;

    /**
     * Constructor.
     */
    public RegressionLine() {
    }

    /**
     * Constructor.
     *
     * @param data the array of data points
     */
    public RegressionLine(DataPoint data[]) {
        for (int i = 0; i < data.length; ++i) {
            addDataPoint(data[i]);
        }
    }

    /** @return the current number of data points */
    public int getDataPointCount() {
        return pn;
    }

    /** @return the coefficient a0 (intercept) */
    public float getA0() {
        validateCoefficients();
        return a0;
    }

    /** @return the coefficient a1 (slope) */
    public float getA1() {
        validateCoefficients();
        return a1;
    }

    /** @return the sum of the x values */
    public double getSumX() {
        return sumX;
    }

    /** @return the sum of the y values */
    public double getSumY() {
        return sumY;
    }

    /** @return the sum of the x*x values */
    public double getSumXX() {
        return sumXX;
    }

    /** @return the sum of the x*y values */
    public double getSumXY() {
        return sumXY;
    }

    /** @return the sum of the y*y values */
    public double getSumYY() {
        return sumYY;
    }

    public int getXMin() {
        return XMin;
    }

    public int getXMax() {
        return XMax;
    }

    public int getYMin() {
        return YMin;
    }

    public int getYMax() {
        return YMax;
    }

    /**
     * Add a new data point: update the sums and the bounds.
     *
     * @param dataPoint the new data point
     */
    public void addDataPoint(DataPoint dataPoint) {
        sumX += dataPoint.x;
        sumY += dataPoint.y;
        sumXX += (double) dataPoint.x * dataPoint.x;
        sumXY += (double) dataPoint.x * dataPoint.y;
        sumYY += (double) dataPoint.y * dataPoint.y;

        int xi = (int) dataPoint.x;
        int yi = (int) dataPoint.y;
        if (pn == 0) {
            XMin = XMax = xi;
            YMin = YMax = yi;
        } else {
            XMin = Math.min(XMin, xi);
            XMax = Math.max(XMax, xi);
            YMin = Math.min(YMin, yi);
            YMax = Math.max(YMax, yi);
        }

        // Echo the point (truncated to integers) and keep it for getR().
        System.out.println(xi + "," + yi);
        points.add(dataPoint);
        ++pn;
        coefsValid = false;
    }

    /**
     * Return the value of the regression line function at x.
     *
     * @param x the value of x
     * @return the value of the function at x, or NaN with fewer than 2 points
     */
    public float at(float x) {
        if (pn < 2)
            return Float.NaN;
        validateCoefficients();
        return a0 + a1 * x;
    }

    /**
     * Reset all accumulated state.
     */
    public void reset() {
        pn = 0;
        sumX = sumY = sumXX = sumXY = sumYY = 0;
        XMin = XMax = YMin = YMax = 0;
        points.clear();
        coefsValid = false;
    }

    /**
     * Validate the coefficients: compute the slope a1 and intercept a0 of
     * y = a1*x + a0.
     */
    private void validateCoefficients() {
        if (coefsValid)
            return;
        double denominator = pn * sumXX - sumX * sumX;
        // The denominator is zero when every x is identical (vertical line).
        if (pn >= 2 && denominator != 0) {
            a1 = (float) ((pn * sumXY - sumX * sumY) / denominator);
            a0 = (float) (sumY / pn - a1 * (sumX / pn));
        } else {
            a0 = a1 = Float.NaN;
        }
        coefsValid = true;
    }

    /**
     * Return the goodness of fit R^2 = 1 - SSE/SST, rounded to 4 decimals.
     */
    public double getR() {
        if (pn < 2)
            return Double.NaN;
        // Local accumulator, so repeated calls do not keep adding up.
        double sse = 0;
        // Visit all pn points; stopping at pn - 1 would drop the last residual.
        for (DataPoint p : points) {
            double deltaY = p.y - at(p.x);
            sse += deltaY * deltaY;
        }
        double sst = sumYY - (sumY * sumY) / pn;
        double r2 = 1 - sse / sst;
        // NaN or infinity (e.g. all y equal, so sst == 0) cannot be rounded.
        if (Double.isNaN(r2) || Double.isInfinite(r2))
            return Double.NaN;
        return round(r2, 4);
    }

    // Exact half-up rounding to the given number of decimal places.
    public double round(double v, int scale) {
        if (scale < 0) {
            throw new IllegalArgumentException("The scale must be a positive integer or zero");
        }
        return new BigDecimal(Double.toString(v)).setScale(scale, RoundingMode.HALF_UP).doubleValue();
    }

    public float round(float v, int scale) {
        if (scale < 0) {
            throw new IllegalArgumentException("The scale must be a positive integer or zero");
        }
        return new BigDecimal(Float.toString(v)).setScale(scale, RoundingMode.HALF_UP).floatValue();
    }
}
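For reference, the closed-form least-squares expressions that validateCoefficients() and getR() are meant to evaluate can be written out as below. The symbols S_x, S_y, S_xx, S_xy and S_yy are shorthand introduced here for sumX, sumY, sumXX, sumXY and sumYY, n stands for pn, and \hat{y}_i = a_0 + a_1 x_i is the fitted value at the i-th point.

```latex
a_1 = \frac{n S_{xy} - S_x S_y}{n S_{xx} - S_x^2}, \qquad
a_0 = \frac{S_y - a_1 S_x}{n} = \bar{y} - a_1 \bar{x}

\mathrm{SSE} = \sum_{i=1}^{n} \left( y_i - \hat{y}_i \right)^2, \qquad
\mathrm{SST} = S_{yy} - \frac{S_y^2}{n}, \qquad
R^2 = 1 - \frac{\mathrm{SSE}}{\mathrm{SST}}
```

Note that the sum in SSE runs over all n points, and that the denominator of a_1 vanishes when every x_i is the same, in which case no slope is defined.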
3. The linear regression test class:
/**
 * File        : LinearRegression.java
 * Author      : zhouyujie
 * Date        : 2012-01-11 16:00:00
 * Description : Simple linear regression in Java; test class for the
 *               regression line (usable for forecasting statistical metrics)
 */
package com.zyujie.dm;

/**
 * <p>
 * <b>Linear Regression</b> <br>
 * Demonstrate linear regression by constructing the regression line for a set
 * of data points.
 *
 * <p>
 * Requires DataPoint.java and RegressionLine.java.
 *
 * <p>
 * To compute the least-squares regression line for the given data points we
 * need SumX, SumY, SumXX and SumXY (note: SumXX = Sum(X^2)).
 * <p>
 * <b>The regression line is: f(x) = a1*x + a0</b>
 * <p>
 * <b>Slope and intercept:</b> <br>
 * n: number of data points
 * <p>
 * a1 = (n*SumXY - SumX*SumY) / (n*SumXX - (SumX)^2) <br>
 * a0 = (SumY - a1*SumX) / n <br>
 * (equivalently a0 = averageY - a1*averageX)
 *
 * <p>
 * <b>Drawing the line: two points determine a line.</b><br>
 * The first point is (0, a0). Substitute any x1 into the equation to get y1,
 * then join (0, a0) and (x1, y1). To make the line span the whole chart, take
 * x1 = Xmax, giving the points (0, a0) and (Xmax, y). If y = a1*Xmax + a0
 * exceeds the largest ordinate Ymax, do not use that point: set y = Ymax,
 * solve for x, and use (0, a0) and (x, Ymax) instead.
 *
 * <p>
 * <b>Goodness of fit (the R^2 reported by Excel):</b>
 * <p>
 * R^2 = 1 - E
 * <p>
 * where the error ratio E = SSE/SST,
 * <p>
 * SSE = sum((Yi - Y)^2), SST = SumYY - (SumY*SumY)/n
 */
public class LinearRegression {

    /**
     * Main program.
     *
     * @param args the array of runtime arguments
     */
    public static void main(String args[]) {
        RegressionLine line = new RegressionLine();
        line.addDataPoint(new DataPoint(1, 136));
        line.addDataPoint(new DataPoint(2, 143));
        line.addDataPoint(new DataPoint(3, 132));
        line.addDataPoint(new DataPoint(4, 142));
        line.addDataPoint(new DataPoint(5, 147));
        printSums(line);
        printLine(line);
    }

    /**
     * Print the computed sums.
     *
     * @param line the regression line
     */
    private static void printSums(RegressionLine line) {
        System.out.println("\nNumber of data points n = " + line.getDataPointCount());
        System.out.println("\nSum x = " + line.getSumX());
        System.out.println("Sum y = " + line.getSumY());
        System.out.println("Sum xx = " + line.getSumXX());
        System.out.println("Sum xy = " + line.getSumXY());
        System.out.println("Sum yy = " + line.getSumYY());
    }

    /**
     * Print the regression line function and its goodness of fit.
     *
     * @param line the regression line
     */
    private static void printLine(RegressionLine line) {
        System.out.println("\nRegression line: y = " + line.getA1() + "x + " + line.getA0());
        System.out.println("Goodness of fit: R^2 = " + line.getR());
    }

    // y = 2.1x + 133.7    2.1 * 6 + 133.7 = 12.6 + 133.7 = 146.3
    // y = 2.1x + 133.7    2.1 * 7 + 133.7 = 14.7 + 133.7 = 148.4
}
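The line-drawing rule described in the class comment can be sketched as a small helper. This is only an illustration under stated assumptions: LineEndpoints and endpoints are hypothetical names that do not appear in the original code, and the sketch assumes a positive slope and an intercept that already lies inside the chart (0 <= a0 <= yMax).

```java
import java.util.Arrays;

public class LineEndpoints {
    /**
     * Returns {{x0, y0}, {x1, y1}} for drawing y = a1*x + a0 on a chart whose
     * axes run from 0 to xMax and from 0 to yMax.
     * Assumption (not from the original code): a1 > 0 and 0 <= a0 <= yMax.
     */
    public static double[][] endpoints(double a0, double a1, double xMax, double yMax) {
        if (a1 <= 0 || a0 < 0 || a0 > yMax) {
            throw new IllegalArgumentException("sketch only handles a1 > 0 and 0 <= a0 <= yMax");
        }
        double y = a1 * xMax + a0;
        if (y <= yMax) {
            // The line leaves the chart through the right edge.
            return new double[][] {{0, a0}, {xMax, y}};
        }
        // The line would overshoot the top edge: clip at y = yMax and solve for x.
        double x = (yMax - a0) / a1;
        return new double[][] {{0, a0}, {x, yMax}};
    }

    public static void main(String[] args) {
        // Fitted line from the example data, chart bounds Xmax = 5, Ymax = 147.
        System.out.println(Arrays.deepToString(endpoints(133.7, 2.1, 5, 147)));
        // A lower Ymax forces the clipped branch.
        System.out.println(Arrays.deepToString(endpoints(133.7, 2.1, 5, 140)));
    }
}
```

With the example line, the first call stays on the right edge because 2.1 * 5 + 133.7 is below 147; the second call clips at the top edge, near x = 3.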
Running the test class produces the following output:
1,136
2,143
3,132
4,142
5,147
Number of data points n = 5
Sum x = 15.0
Sum y = 700.0
Sum xx = 55.0
Sum xy = 2121.0
Sum yy = 98142.0
Regression line: y = 2.1x + 133.7
Goodness of fit: R^2 = 0.3106
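As an independent cross-check of these figures, here is a minimal double-precision sketch that recomputes the slope, the intercept and R^2 directly from the five data points. RegressionCheck and fit are hypothetical names used only for this illustration; they are not part of the original code.

```java
import java.util.Locale;

public class RegressionCheck {
    /** Returns {slope, intercept, R^2} of the least-squares line through (x[i], y[i]). */
    public static double[] fit(double[] x, double[] y) {
        int n = x.length;
        double sx = 0, sy = 0, sxx = 0, sxy = 0, syy = 0;
        for (int i = 0; i < n; i++) {
            sx += x[i];
            sy += y[i];
            sxx += x[i] * x[i];
            sxy += x[i] * y[i];
            syy += y[i] * y[i];
        }
        double a1 = (n * sxy - sx * sy) / (n * sxx - sx * sx);
        double a0 = (sy - a1 * sx) / n;

        // SSE: squared residuals over ALL n points.
        double sse = 0;
        for (int i = 0; i < n; i++) {
            double residual = y[i] - (a0 + a1 * x[i]);
            sse += residual * residual;
        }
        double sst = syy - sy * sy / n;
        return new double[] {a1, a0, 1 - sse / sst};
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3, 4, 5};
        double[] y = {136, 143, 132, 142, 147};
        double[] r = fit(x, y);
        // prints: slope=2.1000 intercept=133.7000 R^2=0.3106
        System.out.printf(Locale.ROOT, "slope=%.4f intercept=%.4f R^2=%.4f%n", r[0], r[1], r[2]);
    }
}
```

By hand: SSE = 0.04 + 26.01 + 64 + 0.01 + 7.84 = 97.9 and SST = 98142 - 700^2 / 5 = 142, so R^2 = 1 - 97.9 / 142 ≈ 0.3106. Leaving the fifth residual (7.84) out of SSE would inflate the result to about 0.3658.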
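Suppose a company reports the following monthly revenue (in units of 10,000 yuan):
January: 136
February: 143
March: 132
April: 142
May: 147
Using the regression line y = 2.1x + 133.7, where x is the month number, we can forecast the revenue for June (x = 6):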
y = 2.1 * 6 + 133.7 = 12.6 + 133.7 = 146.3
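The forecast is simply the fitted line evaluated at a future month number. A minimal sketch follows; RevenueForecast and predict are hypothetical names chosen for this illustration, and the coefficients are hard-coded from the result above rather than refitted.

```java
import java.util.Locale;

public class RevenueForecast {
    /** Evaluates the fitted line y = a1*x + a0 at month number x. */
    public static double predict(double a1, double a0, int month) {
        return a1 * month + a0;
    }

    public static void main(String[] args) {
        double a1 = 2.1;   // slope from the regression output
        double a0 = 133.7; // intercept from the regression output
        for (int month = 6; month <= 7; month++) {
            // prints: month 6 -> 146.3, then month 7 -> 148.4
            System.out.printf(Locale.ROOT, "month %d -> %.1f%n", month, predict(a1, a0, month));
        }
    }
}
```

Since R^2 is well below 0.5 for this data, the line explains only part of the month-to-month variation, so these values are better read as a rough trend estimate than as a precise prediction.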