Pytorch中文文檔已出(http://pytorch-cn.readthedocs.io/zh/latest/)。第一篇部落格獻給了pytorch,主要是為了整理自己的思路。
原來使用caffe,總是要編譯,經曆了無數的坑。當開始接觸pytorch時,果斷拔草caffe。
學習Pytorch最好有一些深度學習理論基礎才更好開,廢話不多說,進入主題。 1 先有個框框,再往裡面填東西
當訓練一個神經網路的時候,我們需要有資料,有模型,並且需要設定訓練的參數。為了不亂,我們最好分別定義三個檔案,分別是:資料準備和預先處理traindataset.py+編寫模型model.py+如何訓練main.py(xx.py,xx自己可任意取名)。
今天我們只講資料準備與預先處理階段:traindataset.py(怎樣命名無所謂,as you like)。這個檔案的作用是什麼呢。
統一將映像(或矩陣)返回成torch能處理的[original_iamges.tensor,label.tensor]
我們先跳躍一下看中文介紹是如何匯入資料:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
我們一般關注DataLoader四個參數:
dataset, batch_size, shuffle, num_workers=0
batch_size是你批處理數目,shuffle是否每個epoch都打亂,workers是載入資料的線程數(請查看中文文檔對每個參數的解釋)
我們具體看看“dataset”——載入資料的資料集。
這個dataset應該是[original_iamges.tensor,label.tensor]之類的,我們定義的“traindataset.py”就是產生這個dataset的。
你只需在main.py 檔案import就可調用。
from traindataset import *
2 定義一個py檔案產生我們自己的dataset 這個py檔案一定要1:能輸入我自己的資料路徑 2:還得預先處理吧,比如的裁剪啊~ step 1:先匯入你肯定需要的庫路徑
import torch.utils.data import torch
from tochvision import transforms
t orch.utils.data模組是子類化你的資料 transforms庫對資料預先處理 step 2:自訂dataset類(子類化你的資料)
class MyTrainData(torch.utils.data.Dataset)
這裡繼承了torch.utils.data.Dataset這個類,我們看看這個類在中文文檔中介紹: 所有其他資料集都應該進行子類化。所有子類應該override__len__和__getitem__,前者提供了資料集的大小,後者支援整數索引,範圍從0到len(self)。
當然還有個初始化__init__()
ok,我們臉譜化py檔案,再往裡面加東西(以下為基礎架構):
#encoding:utf-8import torch.utils.data as dataimport torchfrom torchvision import transformsclass trainmydatalala(torch.utils.data.Dataset) #子類化 def __init__(self, root, transform=None, train=True): #第一步初始化 self.root = root
self.train = train
def __getitem__(self, idx): #第二步裝載資料,返回[img,label] img = imread(img_path) img = torch.from_numpy(img).float() gt = imread(gt_path) gt = torch.from_numpy(gt).float() return img, gt def __len__(self): return len(self.imagenumber)
現在往框框裡面填:(1)是否transform如裁剪、歸一化、旋轉等。(2)是否區分test和train。(3)如何做到一張一張對應讀取圖片。 Python讀取圖片需要scipy庫,要想批處理讀取,還需要os庫; 以下貼出完整代碼:
#encoding:utf-8import torch.utils.data as dataimport torchfrom scipy.ndimage import imreadimport osimport os.pathimport globfrom torchvision import transformsdef make_dataset(root, train=True): dataset = [] if train: dirgt = os.path.join(root, 'train_data/groundtruth') dirimg = os.path.join(root, 'train_data/imgs') for fGT in glob.glob(os.path.join(dirgt, '*.jpg')): # for k in range(45) fName = os.path.basename(fGT) fImg = 'train_ori'+fName[8:] dataset.append( [os.path.join(dirimg, fImg), os.path.join(dirgt, fName)] ) return dataset#自定義dataset的架構class kaggle2016nerve1(data.Dataset): #需要繼承data.Dataset def __init__(self, root, transform=None, train=True): #初始設定檔案路進或檔案名稱 self.train = train if self.train: self.train_set_path = make_dataset(root, train) def __getitem__(self, idx): if self.train: img_path, gt_path = self.train_set_path[idx] img = imread(img_path) img = np.atleast_3d(img).transpose(2, 0, 1).astype(np.float32) img = (img - img.min()) / (img.max() - img.min()) img = torch.from_numpy(img).float() gt = imread(gt_path) gt = np.atleast_3d(gt).transpose(2, 0, 1) gt = gt / 255.0 gt = torch.from_numpy(gt).float() return img, gt def __len__(self): return len(self.train_set_path)
這裡的py檔案需要在最後main.py檔案中調用,所以root我並沒有賦值,我會在main,py中賦值。 這裡我並沒有用到“transform”進行預先處理,如果你想用的話,在__getitem__()下面,return img,gt前重新賦值 img = transforms.ToTensor(img)以及gt = transforms.ToTensor(gt)
這需要注意的是,查看中文文檔transforms庫有哪些變換,如果有需要涉及參數的如CenterCrop(size),需要先實參化,如
crop = transforms.CenterCrop(10);再使用:img = crop(img)