caffe中Python層的使用__Python

來源:互聯網
上載者:User

caffe的大多數層是由c++寫成的,藉助於c++的高效性,網路可以快速訓練。但是我們有時候需要自己寫點輸入層以應對各種不同的資料輸入,比如你因為是需要在映像中取塊而不想寫成LMDB,這時候可以考慮使用python直接寫一個層。而且輸入層不需要GPU加速,所需寫起來也比較容易。 python層怎麼用

先看一個網上的例子吧(來自http://chrischoy.github.io/research/caffe-python-layer/)

layer {  type: 'Python'  name: 'loss'  top: 'loss'  bottom: 'ipx'  bottom: 'ipy'  python_param {    # the module name -- usually the filename -- that needs to be in $PYTHONPATH    module: 'pyloss'    # the layer name -- the class name in the module    layer: 'EuclideanLossLayer'  }  # set loss weight so Caffe knows this is a loss layer  loss_weight: 1}

這裡的type就只有Python一種,然後top,bottom和常見的層是一樣的,module就是你的python module名字,一般就是檔案名稱,然後layer就是定義的類的名字。

一般setup、reshape、forword、backword四個函數是必須的,其他函數按自己的需求來補充,這四個函數格式如下:

def setup(self, bottom, top)、def reshape(self, bottom, top)、def forward(self, bottom, top)
def backward(self, top, propagate_down, bottom):


這裡就以 Fully Convolutional Networks for Semantic Segmentation 論文中公布的代碼作為樣本,解釋python層該怎麼寫。

import caffeimport numpy as npfrom PIL import Imageimport randomclass VOCSegDataLayer(caffe.Layer):    """ Load (input image, label image) pairs from PASCAL VOC one-at-a-time while reshaping the net to preserve dimensions. Use this to feed data to a fully convolutional network. """    def setup(self, bottom, top):        """ Setup data layer according to parameters: - voc_dir: path to PASCAL VOC year dir - split: train / val / test - mean: tuple of mean values to subtract - randomize: load in random order (default: True) - seed: seed for randomization (default: None / current time) for PASCAL VOC semantic segmentation. example params = dict(voc_dir="/path/to/PASCAL/VOC2011", mean=(104.00698793, 116.66876762, 122.67891434), split="val") """        # config        params = eval(self.param_str)        self.voc_dir = params['voc_dir']        self.split = params['split']        self.mean = np.array(params['mean'])        self.random = params.get('randomize', True)        self.seed = params.get('seed', None)        # two tops: data and label        if len(top) != 2:            raise Exception("Need to define two tops: data and label.")        # data layers have no bottoms        if len(bottom) != 0:            raise Exception("Do not define a bottom.")        # load indices for images and labels        split_f  = '{}/ImageSets/Segmentation/{}.txt'.format(self.voc_dir,                self.split)        self.indices = open(split_f, 'r').read().splitlines()        self.idx = 0        # make eval deterministic        if 'train' not in self.split:            self.random = False        # randomization: seed and pick        if self.random:            random.seed(self.seed)            self.idx = random.randint(0, len(self.indices)-1)    def 

聯繫我們

該頁面正文內容均來源於網絡整理,並不代表阿里雲官方的觀點,該頁面所提到的產品和服務也與阿里云無關,如果該頁面內容對您造成了困擾,歡迎寫郵件給我們,收到郵件我們將在5個工作日內處理。

如果您發現本社區中有涉嫌抄襲的內容,歡迎發送郵件至: info-contact@alibabacloud.com 進行舉報並提供相關證據,工作人員會在 5 個工作天內聯絡您,一經查實,本站將立刻刪除涉嫌侵權內容。

A Free Trial That Lets You Build Big!

Start building with 50+ products and up to 12 months usage for Elastic Compute Service

  • Sales Support

    1 on 1 presale consultation

  • After-Sales Support

    24/7 Technical Support 6 Free Tickets per Quarter Faster Response

  • Alibaba Cloud offers highly flexible support services tailored to meet your exact needs.