pytorch學習1:如何載入自己的訓練資料

來源:互聯網
上載者:User


Pytorch中文文檔已出(http://pytorch-cn.readthedocs.io/zh/latest/)。第一篇部落格獻給了pytorch,主要是為了整理自己的思路。

原來使用caffe,總是要編譯,經曆了無數的坑。當開始接觸pytorch時,果斷拔草caffe。

學習Pytorch最好有一些深度學習理論基礎才更好開,廢話不多說,進入主題。 1 先有個框框,再往裡面填東西

當訓練一個神經網路的時候,我們需要有資料,有模型,並且需要設定訓練的參數。為了不亂,我們最好分別定義三個檔案,分別是:資料準備和預先處理traindataset.py+編寫模型model.py+如何訓練main.py(xx.py,xx自己可任意取名)。

今天我們只講資料準備與預先處理階段:traindataset.py(怎樣命名無所謂,as you like)。這個檔案的作用是什麼呢。

統一將映像(或矩陣)返回成torch能處理的[original_iamges.tensor,label.tensor]

我們先跳躍一下看中文介紹是如何匯入資料:

 torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
我們一般關注DataLoader四個參數:
dataset, batch_size, shuffle, num_workers=0
batch_size是你批處理數目,shuffle是否每個epoch都打亂,workers是載入資料的線程數(請查看中文文檔對每個參數的解釋)

我們具體看看“dataset”——載入資料的資料集。

這個dataset應該是[original_iamges.tensor,label.tensor]之類的,我們定義的“traindataset.py”就是產生這個dataset的。


你只需在main.py 檔案import就可調用。

from traindataset import *

2 定義一個py檔案產生我們自己的dataset 這個py檔案一定要1:能輸入我自己的資料路徑 2:還得預先處理吧,比如的裁剪啊~ step 1:先匯入你肯定需要的庫路徑

import torch.utils.data import torch
from tochvision import transforms
t orch.utils.data模組是子類化你的資料 transforms庫對資料預先處理 step 2:自訂dataset類(子類化你的資料)
class MyTrainData(torch.utils.data.Dataset)
這裡繼承了torch.utils.data.Dataset這個類,我們看看這個類在中文文檔中介紹: 所有其他資料集都應該進行子類化。所有子類應該override__len__和__getitem__,前者提供了資料集的大小,後者支援整數索引,範圍從0到len(self)。
 當然還有個初始化__init__() 
ok,我們臉譜化py檔案,再往裡面加東西(以下為基礎架構): 
 
#encoding:utf-8import torch.utils.data as dataimport torchfrom torchvision import transformsclass trainmydatalala(torch.utils.data.Dataset) #子類化  def __init__(self, root, transform=None, train=True): #第一步初始化    self.root = root   
    self.train = train
  def __getitem__(self, idx): #第二步裝載資料,返回[img,label]      img = imread(img_path)      img = torch.from_numpy(img).float()      gt = imread(gt_path)          gt = torch.from_numpy(gt).float()      return img, gt   def __len__(self):    return len(self.imagenumber)

現在往框框裡面填:(1)是否transform如裁剪、歸一化、旋轉等。(2)是否區分test和train。(3)如何做到一張一張對應讀取圖片。 Python讀取圖片需要scipy庫,要想批處理讀取,還需要os庫; 以下貼出完整代碼:
#encoding:utf-8import torch.utils.data as dataimport torchfrom scipy.ndimage import imreadimport osimport os.pathimport globfrom torchvision import transformsdef make_dataset(root, train=True):  dataset = []  if train:    dirgt = os.path.join(root, 'train_data/groundtruth')     dirimg = os.path.join(root, 'train_data/imgs')    for fGT in glob.glob(os.path.join(dirgt, '*.jpg')):    # for k in range(45)      fName = os.path.basename(fGT)          fImg = 'train_ori'+fName[8:]      dataset.append( [os.path.join(dirimg, fImg), os.path.join(dirgt, fName)] )  return dataset#自定義dataset的架構class kaggle2016nerve1(data.Dataset):   #需要繼承data.Dataset  def __init__(self, root, transform=None, train=True): #初始設定檔案路進或檔案名稱    self.train = train    if self.train:      self.train_set_path = make_dataset(root, train)  def __getitem__(self, idx):    if self.train:      img_path, gt_path = self.train_set_path[idx]      img = imread(img_path)      img = np.atleast_3d(img).transpose(2, 0, 1).astype(np.float32)      img = (img - img.min()) / (img.max() - img.min())      img = torch.from_numpy(img).float()      gt = imread(gt_path)      gt = np.atleast_3d(gt).transpose(2, 0, 1)      gt = gt / 255.0      gt = torch.from_numpy(gt).float()      return img, gt    def __len__(self):    return len(self.train_set_path)   
這裡的py檔案需要在最後main.py檔案中調用,所以root我並沒有賦值,我會在main,py中賦值。 這裡我並沒有用到“transform”進行預先處理,如果你想用的話,在__getitem__()下面,return img,gt前重新賦值 img = transforms.ToTensor(img)以及gt = transforms.ToTensor(gt) 這需要注意的是,查看中文文檔transforms庫有哪些變換,如果有需要涉及參數的如CenterCrop(size),需要先實參化,如 crop = transforms.CenterCrop(10);再使用:img = crop(img)

聯繫我們

該頁面正文內容均來源於網絡整理,並不代表阿里雲官方的觀點,該頁面所提到的產品和服務也與阿里云無關,如果該頁面內容對您造成了困擾,歡迎寫郵件給我們,收到郵件我們將在5個工作日內處理。

如果您發現本社區中有涉嫌抄襲的內容,歡迎發送郵件至: info-contact@alibabacloud.com 進行舉報並提供相關證據,工作人員會在 5 個工作天內聯絡您,一經查實,本站將立刻刪除涉嫌侵權內容。

A Free Trial That Lets You Build Big!

Start building with 50+ products and up to 12 months usage for Elastic Compute Service

  • Sales Support

    1 on 1 presale consultation

  • After-Sales Support

    24/7 Technical Support 6 Free Tickets per Quarter Faster Response

  • Alibaba Cloud offers highly flexible support services tailored to meet your exact needs.