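Most of Caffe's layers are written in C++, and because C++ is efficient, networks can be trained quickly. Sometimes, though, we need to write our own input layer to cope with different kinds of data input. For example, if you need to crop patches out of images and don't want to write them into an LMDB first, you can consider writing the layer directly in Python. An input layer also doesn't need GPU acceleration, so it is relatively easy to write.
How to use a Python layer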
Let's first look at an example from the web (from http://chrischoy.github.io/research/caffe-python-layer/):
layer {
  type: 'Python'
  name: 'loss'
  top: 'loss'
  bottom: 'ipx'
  bottom: 'ipy'
  python_param {
    # the module name -- usually the filename -- that needs to be in $PYTHONPATH
    module: 'pyloss'
    # the layer name -- the class name in the module
    layer: 'EuclideanLossLayer'
  }
  # set loss weight so Caffe knows this is a loss layer
  loss_weight: 1
}
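Here type is always 'Python', and top and bottom work the same way as in ordinary layers. module is the name of your Python module, usually the file name without the .py extension, and layer is the name of the class defined in that module.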
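Roughly speaking, Caffe's layer factory handles these two fields by importing the module named in module and then looking up the attribute named in layer, which is why the module must be importable (that is, on $PYTHONPATH). The sketch below only mimics that lookup with the standard importlib module; resolve_python_layer is a hypothetical helper for illustration, not part of Caffe, and the throwaway pyloss.py it writes contains an empty placeholder class.

```python
import importlib
import os
import sys
import tempfile


def resolve_python_layer(module_name, layer_name):
    """Hypothetical helper: import `module_name` and return its attribute `layer_name`."""
    module = importlib.import_module(module_name)  # must be importable, e.g. via $PYTHONPATH
    return getattr(module, layer_name)             # the layer class


# Demo: write a placeholder pyloss.py into a temporary directory and put that
# directory on sys.path, which is what exporting it in $PYTHONPATH achieves.
tmp_dir = tempfile.mkdtemp()
with open(os.path.join(tmp_dir, "pyloss.py"), "w") as f:
    f.write("class EuclideanLossLayer(object):\n    pass\n")
sys.path.insert(0, tmp_dir)
importlib.invalidate_caches()

cls = resolve_python_layer("pyloss", "EuclideanLossLayer")
print(cls.__name__)  # EuclideanLossLayer
```

If the import fails at net construction time, the usual cause is that the directory containing the module is not on $PYTHONPATH.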
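Generally, the four methods setup, reshape, forward and backward must be implemented; add any other methods you need. setup runs once when the net is built, reshape runs before every forward pass, forward computes the tops from the bottoms, and backward computes the bottom gradients from the top gradients. Their signatures are: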
def setup(self, bottom, top):
def reshape(self, bottom, top):
def forward(self, bottom, top):
def backward(self, top, propagate_down, bottom):
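To see the four methods working together, here is a minimal sketch modeled on the pyloss EuclideanLossLayer that the prototxt above refers to. It rests on stated assumptions: FakeBlob is a hypothetical stand-in that mimics only the data, diff, num, count and reshape members of a Caffe blob, and the layer subclasses object instead of caffe.Layer so the code runs without Caffe installed. backward also scales by top[0].diff[0], which Caffe fills with the loss_weight for a loss layer.

```python
import numpy as np


class FakeBlob(object):
    """Hypothetical stand-in for a Caffe blob (only what this sketch uses)."""

    def __init__(self, data):
        self.data = np.asarray(data, dtype=np.float32)
        self.diff = np.zeros_like(self.data)

    @property
    def num(self):
        return self.data.shape[0]

    @property
    def count(self):
        return self.data.size

    def reshape(self, *shape):
        self.data = np.zeros(shape, dtype=np.float32)
        self.diff = np.zeros(shape, dtype=np.float32)


# In a real pyloss.py this would be: class EuclideanLossLayer(caffe.Layer)
class EuclideanLossLayer(object):
    def setup(self, bottom, top):
        # Runs once when the net is built: validate the inputs.
        if len(bottom) != 2:
            raise Exception("Need two inputs to compute distance.")

    def reshape(self, bottom, top):
        # Runs before every forward pass: size internal buffers and tops.
        if bottom[0].count != bottom[1].count:
            raise Exception("Inputs must have the same dimension.")
        self.diff = np.zeros_like(bottom[0].data, dtype=np.float32)
        top[0].reshape(1)  # the loss is a single scalar

    def forward(self, bottom, top):
        # loss = sum((x - y)^2) / (2 * N)
        self.diff[...] = bottom[0].data - bottom[1].data
        top[0].data[...] = np.sum(self.diff ** 2) / bottom[0].num / 2.0

    def backward(self, top, propagate_down, bottom):
        # dloss/dx = diff / N and dloss/dy = -diff / N, scaled by the top
        # gradient (Caffe sets it to loss_weight for a loss layer).
        for i in range(2):
            if not propagate_down[i]:
                continue
            sign = 1.0 if i == 0 else -1.0
            bottom[i].diff[...] = sign * top[0].diff[0] * self.diff / bottom[i].num


# Drive the layer by hand, in the order Caffe would call it.
x = FakeBlob([[1.0, 2.0], [3.0, 4.0]])
y = FakeBlob([[1.0, 0.0], [3.0, 2.0]])
loss = FakeBlob([0.0])
layer = EuclideanLossLayer()
layer.setup([x, y], [loss])
layer.reshape([x, y], [loss])
layer.forward([x, y], [loss])
print(float(loss.data[0]))  # 2.0

loss.diff[...] = 1.0  # what loss_weight: 1 would put here
layer.backward([loss], [True, True], [x, y])
print(x.diff.tolist())  # [[0.0, 1.0], [0.0, 1.0]]
```

Keeping the difference in self.diff during forward means backward can reuse it instead of recomputing bottom[0].data - bottom[1].data.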
Here I'll use the code released with the paper Fully Convolutional Networks for Semantic Segmentation as a sample to explain how a Python layer should be written.
import caffe
import numpy as np
from PIL import Image
import random


class VOCSegDataLayer(caffe.Layer):
    """
    Load (input image, label image) pairs from PASCAL VOC
    one-at-a-time while reshaping the net to preserve dimensions.

    Use this to feed data to a fully convolutional network.
    """

    def setup(self, bottom, top):
        """
        Setup data layer according to parameters:

        - voc_dir: path to PASCAL VOC year dir
        - split: train / val / test
        - mean: tuple of mean values to subtract
        - randomize: load in random order (default: True)
        - seed: seed for randomization (default: None / current time)

        for PASCAL VOC semantic segmentation.

        example

        params = dict(voc_dir="/path/to/PASCAL/VOC2011",
            mean=(104.00698793, 116.66876762, 122.67891434),
            split="val")
        """
        # config
        params = eval(self.param_str)
        self.voc_dir = params['voc_dir']
        self.split = params['split']
        self.mean = np.array(params['mean'])
        self.random = params.get('randomize', True)
        self.seed = params.get('seed', None)

        # two tops: data and label
        if len(top) != 2:
            raise Exception("Need to define two tops: data and label.")
        # data layers have no bottoms
        if len(bottom) != 0:
            raise Exception("Do not define a bottom.")

        # load indices for images and labels
        split_f = '{}/ImageSets/Segmentation/{}.txt'.format(self.voc_dir,
                self.split)
        self.indices = open(split_f, 'r').read().splitlines()
        self.idx = 0

        # make eval deterministic
        if 'train' not in self.split:
            self.random = False

        # randomization: seed and pick
        if self.random:
            random.seed(self.seed)
            self.idx = random.randint(0, len(self.indices)-1)

    def