python讀取MNIST image資料

來源:互聯網
上載者:User

標籤:gic   title   資料集   zip   matrix   oms   AC   UI   mat   

Lecun Mnist資料集下載

import os
import struct

import numpy as np


def _read_idx_bytes(filename, dataset_dir):
    """Return the full binary content of ``filename`` inside ``dataset_dir``."""
    # ``with`` guarantees the handle is closed even if ``read`` raises.
    with open(os.path.join(dataset_dir, filename), 'rb') as binfile:
        return binfile.read()


def loadImageSet(which=0, dataset_dir="..//dataset"):
    """Load an MNIST image file in the raw (uncompressed) IDX3 format.

    Args:
        which: 0 selects ``train-images-idx3-ubyte``; any other value
            selects ``t10k-images-idx3-ubyte``.
        dataset_dir: directory that contains the IDX files.  Defaults to the
            location the script originally hard-coded.

    Returns:
        Integer ndarray of shape ``[image_count, rows, cols]`` holding the
        pixel values (0-255) in the platform default integer dtype.

    Raises:
        IOError/OSError: if the file cannot be opened.
        struct.error: if the file is shorter than the 16-byte header.
        ValueError: if the file holds fewer pixels than the header announces.
    """
    print("load image set")
    if which == 0:
        filename = "train-images-idx3-ubyte"
    else:
        filename = "t10k-images-idx3-ubyte"
    buffers = _read_idx_bytes(filename, dataset_dir)

    # Header: four big-endian uint32 values (magic, count, rows, cols).
    head = struct.unpack_from('>IIII', buffers, 0)
    print("head, %s" % (head,))
    offset = struct.calcsize('>IIII')
    imgNum, rows, cols = head[1], head[2], head[3]

    # e.g. 60000 * 28 * 28 unsigned bytes.  Reading them straight from the
    # buffer avoids building a tuple with one Python object per pixel.
    pixels = np.frombuffer(buffers, dtype=np.uint8,
                           count=imgNum * rows * cols, offset=offset)
    # astype(int) keeps the integer dtype callers got from the tuple-based
    # implementation and yields a writable copy.
    imgs = pixels.astype(int).reshape([imgNum, rows, cols])
    print("load imgs finished")
    return imgs


def loadLabelSet(which=0, dataset_dir="..//dataset"):
    """Load an MNIST label file in the raw (uncompressed) IDX1 format.

    Args:
        which: 0 selects ``train-labels-idx1-ubyte``; any other value
            selects ``t10k-labels-idx1-ubyte``.
        dataset_dir: directory that contains the IDX files.

    Returns:
        Integer ndarray of shape ``[label_count, 1]``.

    Raises:
        IOError/OSError: if the file cannot be opened.
        struct.error: if the file is shorter than the 8-byte header.
        ValueError: if the file holds fewer labels than the header announces.
    """
    print("load label set")
    if which == 0:
        filename = "train-labels-idx1-ubyte"
    else:
        filename = "t10k-labels-idx1-ubyte"
    buffers = _read_idx_bytes(filename, dataset_dir)

    # Header: two big-endian uint32 values (magic, count).
    head = struct.unpack_from('>II', buffers, 0)
    print("head, %s" % (head,))
    imgNum = head[1]
    offset = struct.calcsize('>II')

    raw = np.frombuffer(buffers, dtype=np.uint8, count=imgNum, offset=offset)
    labels = raw.astype(int).reshape([imgNum, 1])
    print('load label finished')
    return labels


if __name__ == "__main__":
    imgs = loadImageSet()
    # To visualise a sample, the original author used a local helper:
    #   import PlotUtil as pu; pu.showImgMatrix(imgs[0])
    loadLabelSet()

及方便訓練的reader

import gzip
import struct
import sys

import numpy as np

try:
    import cPickle  # Python 2
except ImportError:
    import pickle as cPickle  # Python 3: keep the name the script uses


class MnistReader():
    """Mini-batch reader for the pickled MNIST dataset (``mnist.pkl.gz``)."""

    def __init__(self, mnist_path, data_dim=1, one_hot=True):
        '''
        mnist_path: the path of mnist.pkl.gz
        data_dim=1: every image is returned as stored (flat, [784])
        data_dim=3: every image is reshaped to [28,28,1]
        one_hot: one hot encoding(like: [0,1,0,0,0,0,0,0,0,0]) if true
        '''
        self.mnist_path = mnist_path
        self.data_dim = data_dim
        self.one_hot = one_hot
        self.load_minist(mnist_path)
        # Materialise the pairs: on Python 3 ``zip`` is a one-shot iterator
        # that supports neither len(), slicing nor in-place shuffling.
        self.train_datalabel = list(zip(self.train_x, self.train_y))
        self.valid_datalabel = list(zip(self.valid_x, self.valid_y))
        # Index of the next training batch within the current epoch.
        self.batch_offset_train = 0

    def _build_batch(self, pairs):
        """Turn (image, label) pairs into two parallel lists per the settings."""
        imgs = []
        labels = []
        for d, l in pairs:
            if self.data_dim == 3:
                d = np.reshape(d, [28, 28, 1])
            imgs.append(d)
            if self.one_hot:
                a = np.zeros(10)
                a[l] = 1
                # Append the encoded vector, not the raw label.
                labels.append(a)
            else:
                labels.append(l)
        return imgs, labels

    def next_batch_train(self, batch_size):
        '''
        Return the next training batch; batches within an epoch do not
        overlap.  Once len(train)//batch_size batches were served, the
        training pairs are reshuffled and a new epoch starts (a trailing
        partial batch is dropped).

        return list of images, each [784] or [28,28,1] depending on
               self.data_dim, and list of labels, each a scalar or a
               length-10 one-hot vector depending on self.one_hot

        raises ValueError if batch_size is not in 1..len(training set);
               such a value could never be served.
        '''
        total = len(self.train_datalabel)
        if batch_size <= 0 or batch_size > total:
            raise ValueError(
                "batch_size must be between 1 and %d, got %r"
                % (total, batch_size))
        if self.batch_offset_train >= total // batch_size:
            # Epoch exhausted: restart on a freshly shuffled ordering.
            self.batch_offset_train = 0
            np.random.shuffle(self.train_datalabel)
        # batch_offset_train counts batches, so scale it to an element index.
        start = self.batch_offset_train * batch_size
        batch = self.train_datalabel[start:start + batch_size]
        self.batch_offset_train += 1
        return self._build_batch(batch)

    def next_batch_val(self, batch_size):
        '''
        Return a random validation batch: the validation pairs are shuffled
        in place and the first batch_size of them are returned.

        return list of images, each [784] or [28,28,1] depending on
               self.data_dim, and list of labels, each a scalar or a
               length-10 one-hot vector depending on self.one_hot
        '''
        np.random.shuffle(self.valid_datalabel)
        # Draw from the validation pairs that were just shuffled.
        return self._build_batch(self.valid_datalabel[0:batch_size])

    def load_minist(self, dataset):
        """Load the train/valid/test splits from the gzip-compressed pickle.

        Sets ``train_x/train_y``, ``valid_x/valid_y`` and ``test_x/test_y``.
        """
        print("load dataset")
        # SECURITY: unpickling executes arbitrary code embedded in the file;
        # only load dataset files obtained from a trusted source.
        with gzip.open(dataset, 'rb') as f:
            if sys.version_info[0] >= 3:
                # latin1 lets Python 3 read arrays pickled by Python 2.
                train_set, valid_set, test_set = cPickle.load(
                    f, encoding='latin1')
            else:
                train_set, valid_set, test_set = cPickle.load(f)
        self.train_x, self.train_y = train_set
        self.valid_x, self.valid_y = valid_set
        self.test_x, self.test_y = test_set
        print("train image,label shape: %s %s"
              % (self.train_x.shape, self.train_y.shape))
        print("valid image,label shape: %s %s"
              % (self.valid_x.shape, self.valid_y.shape))
        print("test  image,label shape: %s %s"
              % (self.test_x.shape, self.test_y.shape))
        print("load dataset end")


if __name__ == "__main__":
    mnist = MnistReader('../dataset/mnist.pkl.gz', data_dim=3)
    data, label = mnist.next_batch_train(batch_size=1)
    print(data)
    print(label)

第三種載入方式需要 gzip、struct 和 numpy

import gzip
import os
import struct

import numpy as np


def _read(image, label, minist_dir='your_dir/'):
    """Read one gzip-compressed MNIST image/label file pair.

    Args:
        image: file name of the ``*-images-idx3-ubyte.gz`` archive.
        label: file name of the ``*-labels-idx1-ubyte.gz`` archive.
        minist_dir: directory containing both archives.  The default is the
            placeholder from the original snippet and must be replaced (or
            overridden) with the real dataset location.

    Returns:
        Tuple ``(images, labels)``: ``images`` is a uint8 ndarray of shape
        ``(len(labels), rows, cols)``; ``labels`` is a 1-D int8 ndarray.

    Raises:
        IOError/OSError: if an archive cannot be opened.
        struct.error: if a header is truncated.
        ValueError: if the pixel data does not fit the announced shape.
    """
    with gzip.open(os.path.join(minist_dir, label), 'rb') as flbl:
        # 8-byte header: big-endian magic number and item count.
        magic, num = struct.unpack(">II", flbl.read(8))
        # frombuffer supersedes the deprecated binary mode of fromstring;
        # copy() keeps the result writable, as fromstring's result was.
        labels = np.frombuffer(flbl.read(), dtype=np.int8).copy()
    with gzip.open(os.path.join(minist_dir, image), 'rb') as fimg:
        # 16-byte header: big-endian magic, item count, rows, cols.
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        pixels = np.frombuffer(fimg.read(), dtype=np.uint8).copy()
        images = pixels.reshape(len(labels), rows, cols)
    return images, labels


def get_data(minist_dir='your_dir/'):
    """Load the MNIST training and test sets from ``minist_dir``.

    Returns:
        ``[train_img, train_label, test_img, test_label]``.
    """
    train_img, train_label = _read(
        'train-images-idx3-ubyte.gz',
        'train-labels-idx1-ubyte.gz',
        minist_dir)
    test_img, test_label = _read(
        't10k-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz',
        minist_dir)
    return [train_img, train_label, test_img, test_label]

python讀取MNIST image資料

聯繫我們

該頁面正文內容均來源於網絡整理,並不代表阿里雲官方的觀點,該頁面所提到的產品和服務也與阿里雲無關,如果該頁面內容對您造成了困擾,歡迎寫郵件給我們,收到郵件我們將在5個工作日內處理。

如果您發現本社區中有涉嫌抄襲的內容,歡迎發送郵件至: info-contact@alibabacloud.com 進行舉報並提供相關證據,工作人員會在 5 個工作天內聯絡您,一經查實,本站將立刻刪除涉嫌侵權內容。

A Free Trial That Lets You Build Big!

Start building with 50+ products and up to 12 months usage for Elastic Compute Service

  • Sales Support

    1 on 1 presale consultation

  • After-Sales Support

    24/7 Technical Support 6 Free Tickets per Quarter Faster Response

  • Alibaba Cloud offers highly flexible support services tailored to meet your exact needs.