python用K近鄰(KNN)演算法分類MNIST資料集和Fashion MNIST資料集

來源:互聯網
上載者:User

標籤:getter   col   err   array   屬性   orm   分析   簡單   [1]   

一、KNN演算法的介紹

  K最近鄰(k-Nearest Neighbor,KNN)分類演算法是最簡單的機器學習演算法之一,理論上比較成熟。KNN演算法首先將待分類樣本表達成和訓練樣本一致的特徵向量;然後根據距離計算待測試樣本和每個訓練樣本的距離,選擇距離最小的K個樣本作為近鄰樣本;最後根據K個近鄰樣本判斷待分類樣本的類別。KNN演算法的正確選取是分類正確的關鍵因素之一,而近鄰樣本是通過計算測試樣本與每個訓練集樣本的距離來選定的,故定義合適的距離是KNN正確分類的前提。

本文中在上述研究的基礎上,將特徵屬性值對類別判斷的重要性視為同樣重要,將樣本距離重新定義為任意兩樣本間像素點間的相關距離,並且距離計算使用的是距離。

二、演算法原理

  k-近鄰演算法(KNN),其工作原理是存在一個樣本資料集合,也稱作訓練樣本集,並且樣本集中每個資料都存在標籤,即我們知道樣本集中每一資料與所屬分類的對應關係。輸入沒有標籤的新資料後,將新資料的每個特徵與樣本集中資料資料對應的特徵進行比較,然後演算法提取樣本集中特徵最相似資料(最近鄰)的分類標籤。一般來說,我們只選擇樣本資料集中前k個最相似的資料,這就是k-近鄰演算法中k的出處,通常k是不大於20的整數。最後,選擇k個最相似資料中出現次數最多的分類,作為新資料的分類。

  收集和準備資料,這裡使用的是mnist資料集和fashion mnist資料集,輸入樣本資料和結構化的輸出結果,可以調整k的值,然後運行k-近鄰演算法判斷輸入資料分別屬於哪個分類,最後計算錯誤率和準確率。

KNN演算法(k鄰近演算法分類演算法),就是k個最近的鄰居的,說的是每個樣本都可以用它最接近的k個鄰居來代表,核心思想是如果一個樣本在特徵空間中的k個最相鄰的樣本中的大多數屬於某一個類別,則該樣本也屬於這個類別,並具有這個類別上樣本的特性。KNN演算法不僅可以用於分類,還可以用於迴歸。通過找出一個樣本的k個最近鄰居,將這些鄰居的屬性的平均值賦給該樣本,就可以得到該樣本的屬性。在KNN中,通過計算對象間距離來作為各個對象之間的非相似性指標,避免了對象之間的匹配問題,在這裡距離使用的是歐氏距離。

詳細實現:將mnist資料集和fashion mnist資料集包括訓練集和驗證集匯入到工程檔案中,接著計算驗證集和訓練集的距離,並從小到達排序得到距離最近的k個鄰居,並通過投票得到所屬類別最高的類別,並判斷該驗證集的圖片屬於該類別,接著講該類別的標籤和驗證集的標籤進行比對,如果相符合則是正確的,如果不相符合,則是屬於出錯,最後輸出計算出的錯誤率和準確率。

三、資料集介紹  MNIST資料集,訓練集60000張圖片和標籤;測試集有10000張圖片和標籤。讀取28*28圖片以後,要將每張圖片轉換為1*784的向量。四、KNN演算法實現和結果分析代碼實現:
from numpy import *
import operator
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from os import listdir
from mpl_toolkits.mplot3d import Axes3D
import struct

#讀取圖片
def read_image(file_name):
#先用二進位方式把檔案都讀進來
file_handle=open(file_name,"rb") #以二進位開啟文檔
file_content=file_handle.read() #讀取到緩衝區中

offset=0
head = struct.unpack_from(‘>IIII‘, file_content, offset) # 取前4個整數,返回一個元組
offset += struct.calcsize(‘>IIII‘)
imgNum = head[1] #圖片數
rows = head[2] #寬度
cols = head[3] #高度
# print(imgNum)
# print(rows)
# print(cols)

#測試讀取一個圖片是否讀取成功
#im = struct.unpack_from(‘>784B‘, file_content, offset)
#offset += struct.calcsize(‘>784B‘)

images=np.empty((imgNum , 784))#empty,是它所常見的數組內的所有元素均為空白,沒有實際意義,它是建立數組最快的方法
image_size=rows*cols#單個圖片的大小
fmt=‘>‘ + str(image_size) + ‘B‘#單個圖片的format

for i in range(imgNum):
images[i] = np.array(struct.unpack_from(fmt, file_content, offset))
# images[i] = np.array(struct.unpack_from(fmt, file_content, offset)).reshape((rows, cols))
offset += struct.calcsize(fmt)
return images

‘‘‘bits = imgNum * rows * cols # data一共有60000*28*28個像素值
bitsString = ‘>‘ + str(bits) + ‘B‘ # fmt格式:‘>47040000B‘
imgs = struct.unpack_from(bitsString, file_content, offset) # 取data資料,返回一個元組
imgs_array=np.array(imgs).reshape((imgNum,rows*cols)) #最後將讀取的資料reshape成 【圖片數,圖片像素】二維數組
return imgs_array‘‘‘

#讀取標籤
def read_label(file_name):
file_handle = open(file_name, "rb") # 以二進位開啟文檔
file_content = file_handle.read() # 讀取到緩衝區中

head = struct.unpack_from(‘>II‘, file_content, 0) # 取前2個整數,返回一個元組
offset = struct.calcsize(‘>II‘)

labelNum = head[1] # label數
# print(labelNum)
bitsString = ‘>‘ + str(labelNum) + ‘B‘ # fmt格式:‘>47040000B‘
label = struct.unpack_from(bitsString, file_content, offset) # 取data資料,返回一個元組
return np.array(label)

#KNN演算法
def KNN(test_data, dataSet, labels, k):
dataSetSize = dataSet.shape[0]#dataSet.shape[0]表示的是讀取矩陣第一維度長度,代表行數
# distance1 = tile(test_data, (dataSetSize,1)) - dataSet#歐氏距離計算開始
# print("dataSetSize:")
# print(dataSetSize)
distance1 = tile(test_data, (dataSetSize)).reshape((60000,784))-dataSet#tile函數在行上重複dataSetSizec次,在列上重複1次
# print("distance1.shape")
# print(distance1.shape)
distance2 = distance1**2 #每個元素平方
distance3 = distance2.sum(axis=1)#矩陣每行相加
distances4 = distance3**0.5#歐氏距離計算結束
# print(distances4[53843])
# print(distances4[38620])
# print(distances4[16186])
sortedDistIndicies = distances4.argsort() #返回從小到大排序的索引
classCount=np.zeros((10), np.int32)#10是代表10個類別
for i in range(k): #統計前k個資料類的數量
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] += 1
max = 0
id = 0
print(classCount.shape[0])
# print(classCount.shape[1])

for i in range(classCount.shape[0]):
if classCount[i] >= max:
max = classCount[i]
id = i
print(id)

# sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)#從大到小按類別數目排序
return id

def test_KNN():
# 檔案擷取
#mnist資料集
# train_image = "F:\mnist\\train-images-idx3-ubyte"
# test_image = "F:\mnist\\t10k-images-idx3-ubyte"
# train_label = "F:\mnist\\train-labels-idx1-ubyte"
# test_label = "F:\mnist\\t10k-labels-idx1-ubyte"
#fashion mnist資料集
train_image = "train-images-idx3-ubyte"
test_image = "t10k-images-idx3-ubyte"
train_label = "train-labels-idx1-ubyte"
test_label = "t10k-labels-idx1-ubyte"
# 讀取資料
train_x = read_image(train_image) # train_dataSet
test_x = read_image(test_image) # test_dataSet
train_y = read_label(train_label) # train_label
test_y = read_label(test_label) # test_label

# print(train_x.shape)
# print(test_x.shape)
# print(train_y.shape)
# print(test_y.shape)
# plt.imshow(train_x[0])
# plt.show()

testRatio = 1 # 取資料集的前0.1為測試資料,這個參數比重可以改變
train_row = train_x.shape[0] # 資料集的行數,即資料集的總的樣本數
test_row=test_x.shape[0]
testNum = int(test_row * testRatio)
errorCount = 0 # 判斷錯誤的個數
for i in range(testNum):
result = KNN(test_x[i], train_x, train_y, 30)
# print(‘返回的結果是: %s, 真實結果是: %s‘ % (result, train_y[i]))

print(result, test_y[i])
if result != test_y[i]:
errorCount += 1.0# 如果mnist驗證集的標籤和本身標籤不一樣,則出錯
error_rate = errorCount / float(testNum) # 計算出錯率
acc = 1.0 - error_rate
print(errorCount)
print("\nthe total number of errors is: %d" % errorCount)
print("\nthe total error rate is: %f" % (error_rate))
print("\nthe total accuracy rate is: %f" % (acc))

if __name__ == "__main__":
test_KNN()#test()函數中調用了讀取資料集的函數,並調用分類函數對資料集進行分類,最後對分類情況進行計算
結果分析:

 

輸入:mnist資料集或者fashion mnist資料集

輸出:出錯率和準確率

Mnist資料集:

取k=30,驗證集是50個的時候,準確率是1;

取k=30,驗證集是500個的時候,準確率是0.98;

取k=30,驗證集是10000個的時候,準確率是0.84。

Fashion Mnist資料集

K=30,驗證集是10000的時候,一共的出錯個數是1666,準確率是0.8334。

本文中的資料集採用KNN演算法得到了較高的準確率,但是本文中考慮特徵屬性值對類別判斷的重要性一樣,改進演算法時應該考慮特徵屬性值對類別判斷的重要性不同,兩樣本間屬性的相關距離可以用來度量屬性值對類別的重要性,相關距離熵越小,兩樣本的相似程度越大,類可信度越大;此外本文中應該對不同取值的k進行分別的實驗,得到使準確率較高的k,同時在實驗多個k的時候,可以採用多線程進行跑實驗,縮短時間。



 

python用K近鄰(KNN)演算法分類MNIST資料集和Fashion MNIST資料集

相關文章

聯繫我們

該頁面正文內容均來源於網絡整理,並不代表阿里雲官方的觀點,該頁面所提到的產品和服務也與阿里云無關,如果該頁面內容對您造成了困擾,歡迎寫郵件給我們,收到郵件我們將在5個工作日內處理。

如果您發現本社區中有涉嫌抄襲的內容,歡迎發送郵件至: info-contact@alibabacloud.com 進行舉報並提供相關證據,工作人員會在 5 個工作天內聯絡您,一經查實,本站將立刻刪除涉嫌侵權內容。

A Free Trial That Lets You Build Big!

Start building with 50+ products and up to 12 months usage for Elastic Compute Service

  • Sales Support

    1 on 1 presale consultation

  • After-Sales Support

    24/7 Technical Support 6 Free Tickets per Quarter Faster Response

  • Alibaba Cloud offers highly flexible support services tailored to meet your exact needs.