The PyTorch Chinese documentation is out (http://pytorch-cn.readthedocs.io/zh/latest/). This is my first blog post dedicated to PyTorch, written mainly to organize my own thoughts.
I originally used Caffe, which always had to be compiled, and I stepped into countless pits along the way. As soon as I came into contact with PyTorch, I decisively dropped Caffe.
Learning PyTorch goes best with some deep learning theory as a foundation. Without further ado, let's get into the subject. 1. First set up a skeleton, then fill it in.
When training a neural network, we need data, a model, and training parameters. To keep things from getting messy, it's best to define three separate files: data preparation and preprocessing in traindataset.py, the model in model.py, and the training procedure in main.py (the xx in xx.py can be named however you like).
Today we only cover the data preparation and preprocessing stage: traindataset.py (again, the name doesn't matter). What is the role of this file?
It uniformly turns an image (or matrix) into something torch can process: [original_images.tensor, label.tensor].
Let's jump ahead for a moment. The Chinese documentation introduces data loading like this:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
We generally focus on four DataLoader parameters:
dataset, batch_size, shuffle, num_workers
batch_size is the number of samples per batch, shuffle controls whether the data is reshuffled at each epoch, and num_workers is the number of subprocesses used to load the data (see the Chinese documentation for an explanation of every parameter).
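For instance, a minimal sketch of wiring these four parameters together; train_data stands in for the dataset we build below, and the batch size of 4 is just an illustrative value:

import torch.utils.data

# train_data is assumed to be a Dataset instance (we build one below)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=4,
                                           shuffle=True, num_workers=2)
for img, gt in train_loader:
    pass  # each iteration yields one batch of [img.tensor, gt.tensor]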
Let's take a look at dataset: the data set from which the data is loaded.
This dataset should be [original_images.tensor, label.tensor], and the traindataset.py we define is exactly what produces this dataset.
You only need to import it in the main.py file:
from traindataset import *
2. Define a py file that generates our own dataset. This py file must 1) accept our own data path and 2) preprocess the data, e.g., cropping. Step 1: import the libraries you will definitely need:
import torch.utils.data
import torch
from torchvision import transforms
The torch.utils.data module is what you subclass for your own data; the transforms library is for data preprocessing. Step 2: customize the Dataset class (subclass it for your data):
class MyTrainData(torch.utils.data.Dataset):
Here we inherit from the torch.utils.data.Dataset class. The Chinese documentation describes it like this: all other datasets should subclass it, and all subclasses should override __len__ and __getitem__; the former provides the size of the dataset, the latter supports integer indexing in the range from 0 to len(self).
And of course there is also the initializer, __init__().
OK, let's lay the py file out and add things to it step by step (below is the base skeleton):
# encoding: utf-8
import torch.utils.data as data
import torch
from torchvision import transforms

class Trainmydatalala(torch.utils.data.Dataset):  # subclass Dataset
    def __init__(self, root, transform=None, train=True):  # step 1: initialization
        self.root = root
        self.train = train

    def __getitem__(self, idx):  # step 2: load the data, return [img, label]
        # img_path, gt_path, imread, and self.imagenumber are placeholders filled in below
        img = imread(img_path)
        img = torch.from_numpy(img).float()
        gt = imread(gt_path)
        gt = torch.from_numpy(gt).float()
        return img, gt

    def __len__(self):
        return len(self.imagenumber)
Now fill in the skeleton: (1) whether to transform, e.g., cropping, normalization, rotation, and so on; (2) whether to distinguish between test and train; (3) how to read each image together with its corresponding ground truth. Reading images in Python requires the scipy library, and for batch reading we also need the os and glob libraries. The full code follows:
# encoding: utf-8
import torch.utils.data as data
import torch
import numpy as np  # needed for np.atleast_3d below
from scipy.ndimage import imread
import os
import os.path
import glob
from torchvision import transforms

def make_dataset(root, train=True):
    dataset = []
    if train:
        dirGT = os.path.join(root, 'train_data/groundtruth')
        dirImg = os.path.join(root, 'train_data/imgs')
        for fGT in glob.glob(os.path.join(dirGT, '*.jpg')):  # each ground-truth file
            fName = os.path.basename(fGT)
            fImg = 'train_ori' + fName[8:]  # map the ground-truth name to its image name
            dataset.append([os.path.join(dirImg, fImg), os.path.join(dirGT, fName)])
    return dataset

# the framework of the custom dataset
class Kaggle2016Nerve1(data.Dataset):  # must inherit data.Dataset
    def __init__(self, root, transform=None, train=True):  # initialize file paths or file names
        self.train = train
        if self.train:
            self.train_set_path = make_dataset(root, train)

    def __getitem__(self, idx):
        if self.train:
            img_path, gt_path = self.train_set_path[idx]
            img = imread(img_path)
            img = np.atleast_3d(img).transpose(2, 0, 1).astype(np.float32)  # HWC -> CHW
            img = (img - img.min()) / (img.max() - img.min())  # normalize to [0, 1]
            img = torch.from_numpy(img).float()
            gt = imread(gt_path)
            gt = np.atleast_3d(gt).transpose(2, 0, 1)
            gt = gt / 255.0
            gt = torch.from_numpy(gt).float()
            return img, gt

    def __len__(self):
        return len(self.train_set_path)
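For reference, a minimal sketch of what the call from main.py might look like; the root path '/path/to/data' and the batch size are placeholder assumptions, not values from this post:

from traindataset import *  # brings in Kaggle2016Nerve1
import torch.utils.data

root = '/path/to/data'  # hypothetical; root gets its value here in main.py
train_data = Kaggle2016Nerve1(root, train=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=4,
                                           shuffle=True, num_workers=2)
for img, gt in train_loader:
    pass  # feed each [img, gt] batch to the model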
As the sketch above shows, the py file defined here is called from the final main.py file, so I don't assign root here; it gets its value in main.py. I also don't use transform for preprocessing; if you want to, re-assign inside __getitem__() before returning img, gt: img = transforms.ToTensor()(img) and gt = transforms.ToTensor()(gt).
Note that you can look up which transformations the transforms library offers in the Chinese documentation. If a transform takes parameters, such as CenterCrop(size), you need to instantiate it with its arguments first, for example:
crop = transforms.CenterCrop(10), and then apply it: img = crop(img)
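A small sketch of this instantiate-then-apply pattern, including a Compose chain; the crop size 10 is just the example value from above, and img is assumed to still be a PIL Image at this point:

from torchvision import transforms

crop = transforms.CenterCrop(10)  # instantiate with its parameter first
img = crop(img)                   # then apply it like a function

# several transforms can be chained with Compose
transform = transforms.Compose([
    transforms.CenterCrop(10),
    transforms.ToTensor(),  # PIL Image/ndarray -> CHW float tensor in [0, 1]
])
img = transform(img)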