標籤:name imp packages matlab csv 線性 tar res square
文章:
http://python.jobbole.com/81215/
Python的函數庫好強大!看完這篇博文再也不會用MATLAB了~~
這篇文章使用【pandas】讀取CSV的資料,使用【sklearn】中的linear_model訓練模型並進行線性預測,使用【matplotlib】將擬合的情況用圖表示出來。
下面的表格是用於訓練模型的資料:
代碼如下:
# -*- coding: utf-8 -*-
"""Fit a one-feature linear regression of price on square footage and plot it.

Created on 2016/11/26
@author: chensi
"""
# Required Packages
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import datasets, linear_model
from numpy.ma.core import getdata


# Function to get data
def get_data(file_name):
    """Load the training table and split it into features and targets.

    The table must contain the columns ``square_feet`` and ``price``.

    Args:
        file_name: Path to the data file. Files ending in ``.csv`` are read
            with ``pandas.read_csv``; anything else is read with
            ``pandas.read_excel``.

    Returns:
        A tuple ``(X_parameter, Y_parameter)`` where ``X_parameter`` is a list
        of one-element lists ``[[square_feet], ...]`` (one row per sample) and
        ``Y_parameter`` is a flat list of prices, all converted to ``float``.
    """
    # The driver passes a .csv path; read_excel cannot parse plain CSV text,
    # so pick the reader from the file extension.
    if str(file_name).lower().endswith('.csv'):
        data = pd.read_csv(file_name)
    else:
        data = pd.read_excel(file_name)

    # Each sample is wrapped in its own list so the feature matrix is 2-D.
    X_parameter = [[float(square_feet)] for square_feet in data['square_feet']]
    Y_parameter = [float(price) for price in data['price']]
    return X_parameter, Y_parameter


# Function for fitting our data to the linear model
def linear_model_main(X_parameters, Y_parameters, predict_value):
    """Fit a linear regression and predict the target for ``predict_value``.

    Args:
        X_parameters: Feature matrix as a list of one-element lists.
        Y_parameters: Target values, one per row of ``X_parameters``.
        predict_value: A scalar, or a sequence of scalars, to predict for.

    Returns:
        A dict with the keys ``intercept``, ``coefficient`` and
        ``predicted_value`` (the array returned by ``predict``).
    """
    # Create linear regression object
    regr = linear_model.LinearRegression()
    regr.fit(X_parameters, Y_parameters)

    # predict() needs a 2-D (n_samples, n_features) array; a bare scalar such
    # as 150 is rejected, so reshape the input into a single-feature column.
    samples = np.asarray(predict_value, dtype=float).reshape(-1, 1)
    predict_outcome = regr.predict(samples)

    return {
        'intercept': regr.intercept_,
        'coefficient': regr.coef_,
        'predicted_value': predict_outcome,
    }


# Function to show the results of the linear fit model
def show_linear_line(X_parameters, Y_parameters):
    """Scatter-plot the samples and draw the fitted regression line.

    Args:
        X_parameters: Feature matrix as a list of one-element lists.
        Y_parameters: Target values, one per row of ``X_parameters``.
    """
    # Create linear regression object
    regr = linear_model.LinearRegression()
    regr.fit(X_parameters, Y_parameters)

    plt.scatter(X_parameters, Y_parameters, color='blue')
    plt.plot(X_parameters, regr.predict(X_parameters), color='red', linewidth=4)
    # Hide the axis ticks.
    plt.xticks(())
    plt.yticks(())
    plt.show()


# ---------Test---------------
if __name__ == "__main__":
    x, y = get_data("g:/input_data.csv")
    show_linear_line(x, y)
    print(linear_model_main(x, y, 150))
輸出的圖:
例子二:
代碼:
# -*- coding: utf-8 -*-
"""Compare two TV shows by linear-regression forecasts of their viewers.

Created on 2016/11/26
@author: chensi
"""
# Required Packages
import csv
import sys
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import datasets, linear_model


# Function to get data
def get_data(file_name):
    """Load episode numbers and viewer counts for both shows.

    The sheet must contain the columns ``flash_episode_number``,
    ``flash_us_viewers``, ``arrow_episode_number`` and ``arrow_us_viewers``.

    Args:
        file_name: Path to the Excel file, read with ``pandas.read_excel``.

    Returns:
        A tuple ``(flash_x, flash_y, arrow_x, arrow_y)``. The ``*_x`` values
        are lists of one-element lists (one row per sample); the ``*_y``
        values are flat lists. All entries are converted to ``float``.
    """
    data = pd.read_excel(file_name)

    # Each episode number is wrapped in its own list so the feature matrix
    # is 2-D.
    flash_x_parameter = [[float(v)] for v in data['flash_episode_number']]
    flash_y_parameter = [float(v) for v in data['flash_us_viewers']]
    arrow_x_parameter = [[float(v)] for v in data['arrow_episode_number']]
    arrow_y_parameter = [float(v) for v in data['arrow_us_viewers']]
    return (flash_x_parameter, flash_y_parameter,
            arrow_x_parameter, arrow_y_parameter)


# Function to know which TV show will have more viewers
def more_viewers(x1, y1, x2, y2, episode_number=9):
    """Fit one regression per show and print which one is predicted higher.

    Args:
        x1: Feature matrix for the first show (The Flash).
        y1: Viewer counts for the first show.
        x2: Feature matrix for the second show (Arrow).
        y2: Viewer counts for the second show.
        episode_number: Episode to predict for. Defaults to 9, the value the
            original script hard-coded.
    """
    # predict() needs a 2-D (n_samples, n_features) array, not a bare scalar.
    next_episode = [[float(episode_number)]]

    regr1 = linear_model.LinearRegression()
    regr1.fit(x1, y1)
    predicted_value1 = regr1.predict(next_episode)
    print(predicted_value1)

    regr2 = linear_model.LinearRegression()
    regr2.fit(x2, y2)
    predicted_value2 = regr2.predict(next_episode)

    # Each prediction is a one-element array; compare the scalars explicitly.
    if predicted_value1[0] > predicted_value2[0]:
        print("The Flash Tv Show will have more viewers for next week")
    else:
        print("Arrow Tv Show will have more viewers for next week")


if __name__ == "__main__":
    x1, y1, x2, y2 = get_data('G:/input_data_2.xlsx')
    more_viewers(x1, y1, x2, y2)
輸出:
python-[pandas]-[sklearn]-[matplotlib]-線性預測