PyTorch Study Notes (20): Ignite (a high-level API for training models)


Note: the diagram that originally accompanied this post is out of date. As of the 2018.04.04 version there are no longer separate Trainer and Evaluator classes; only a single Engine class remains.

Recently I wanted to write a higher-level abstraction to make training PyTorch networks more convenient, and by chance I found an ignite repo under the pytorch GitHub organization, so I got curious about what it was. It turns out PyTorch already provides a high-level library for training PyTorch models. Since the wheel already exists, there is no need to build my own; using it well is enough. Reading the source is also a good way to learn how experienced developers design abstractions. Ignite currently lacks official documentation (the repo only says "API documentation and an overview of the library coming soon."), so this post is mainly a high-level introduction to ignite.

There is no official documentation, but the repo does ship examples, so let's start from one of them.

To keep the listing short, I deleted the parts that have little to do with ignite. If you want to run the complete example, see the link mentioned above.

from argparse import ArgumentParser

import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import SGD
from torchvision.transforms import Compose, ToTensor, Normalize

from ignite.engines import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import CategoricalAccuracy, Loss


# Net and get_data_loaders are defined in the full example and omitted here.
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    cuda = torch.cuda.is_available()
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)

    model = Net()
    if cuda:
        model = model.cuda()
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, cuda=cuda)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': CategoricalAccuracy(),
                                                     'nll': Loss(F.nll_loss)},
                                            cuda=cuda)

    # Log the training loss every log_interval iterations.
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_interval == 0:
            print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}".format(
                engine.state.epoch, iter, len(train_loader), engine.state.output))

    # Run the evaluator on the validation set at the end of every epoch.
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        metrics = evaluator.run(val_loader).metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        print("Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
            engine.state.epoch, avg_accuracy, avg_nll))

    trainer.run(train_loader, max_epochs=epochs)

First a summary of the overall process; after that we look at what each API does:

1. Create the model and the DataLoaders
2. Create the trainer
3. Create the evaluator
4. Register functions for some events with @trainer.on()
5. Call trainer.run()

Event

from enum import Enum


class Events(Enum):
    """An enumeration class that defines the available events."""
    EPOCH_STARTED = "epoch_started"               # triggered when a new epoch starts
    EPOCH_COMPLETED = "epoch_completed"           # triggered when an epoch ends
    STARTED = "started"                           # triggered when training starts
    COMPLETED = "completed"                       # triggered when training ends
    ITERATION_STARTED = "iteration_started"       # triggered when an iteration starts
    ITERATION_COMPLETED = "iteration_completed"   # triggered when an iteration ends
    EXCEPTION_RAISED = "exception_raised"         # triggered when an exception occurs
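To make the comments above concrete, here is a small sketch of the order in which these events would fire during a normal run without exceptions. The `firing_order` helper is hypothetical and written only for this illustration; it is not part of ignite, and the enum is redefined so the snippet runs on its own.

```python
from enum import Enum


class Events(Enum):
    # Same members as the enum above, redefined so the sketch is self-contained.
    EPOCH_STARTED = "epoch_started"
    EPOCH_COMPLETED = "epoch_completed"
    STARTED = "started"
    COMPLETED = "completed"
    ITERATION_STARTED = "iteration_started"
    ITERATION_COMPLETED = "iteration_completed"
    EXCEPTION_RAISED = "exception_raised"


def firing_order(num_epochs, num_batches):
    """Yield events in the order an exception-free run would trigger them."""
    yield Events.STARTED
    for _ in range(num_epochs):
        yield Events.EPOCH_STARTED
        for _ in range(num_batches):
            yield Events.ITERATION_STARTED
            yield Events.ITERATION_COMPLETED
        yield Events.EPOCH_COMPLETED
    yield Events.COMPLETED


# One epoch over two batches.
order = [event.name for event in firing_order(num_epochs=1, num_batches=2)]
print(order)
```

EXCEPTION_RAISED never appears in this list because it is only triggered when something goes wrong during the run.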

State

class State(object):
    def __init__(self, **kwargs):
        self.iteration = 0            # iteration counter
        self.output = None            # output of the current iteration; for the supervised trainer this is the loss
        self.batch = None             # mini-batch used in the current iteration
        for k, v in kwargs.items():   # any other values you want State to record
            setattr(self, k, v)
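A short usage sketch of the class above shows how the keyword arguments end up as attributes. The names `epoch`, `max_epochs` and `metrics` are chosen here only because the example at the top reads `state.epoch` and `state.metrics`; which keywords ignite itself passes is an assumption.

```python
class State(object):
    def __init__(self, **kwargs):
        self.iteration = 0
        self.output = None
        self.batch = None
        for k, v in kwargs.items():
            setattr(self, k, v)


# Extra keyword arguments become plain attributes on the instance.
state = State(epoch=0, max_epochs=5, metrics={})
state.iteration += 1
print(state.epoch, state.max_epochs, state.iteration)

# Because setattr runs after the defaults, a keyword can also override them.
resumed = State(iteration=100)
print(resumed.iteration)
```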

create_supervised_trainer

def create_supervised_trainer(model, optimizer, loss_fn, cuda=False):
    """
    Factory function for creating a trainer for supervised models

    Args:
        model (torch.nn.Module): the model to train
        optimizer (torch.optim.Optimizer): the optimizer to use
        loss_fn (torch.nn loss function): the loss function to use
        cuda (bool, optional): whether or not to transfer batch to GPU (default: False)

    Returns:
        Trainer: a trainer instance with supervised update function
    """
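The docstring does not show what a "supervised update function" looks like. The sketch below is a toy stand-in in plain Python, not ignite's code: `ScalarModel` and `make_update` are hypothetical names, the model has a single weight, and the gradient is computed by hand. With a real `torch.nn.Module`, steps 4 and 5 would be `loss.backward()` and `optimizer.step()`.

```python
class ScalarModel:
    """Toy one-parameter model y = w * x, standing in for a torch.nn.Module."""

    def __init__(self, w=0.0):
        self.w = w

    def __call__(self, x):
        return self.w * x


def make_update(model, lr):
    """Return an update function with the signature func(batch) -> loss."""

    def update(batch):
        xs, ys = batch                                    # 1. unpack the batch
        preds = [model(x) for x in xs]                    # 2. forward computation
        n = len(xs)
        loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n          # 3. mean squared error
        grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n  # 4. d(loss)/d(w)
        model.w -= lr * grad                              # 5. gradient-descent step
        return loss                                       # 6. value the engine would store as output

    return update


model = ScalarModel()
update = make_update(model, lr=0.1)
batch = ([1.0, 2.0], [2.0, 4.0])   # targets follow y = 2x
losses = [update(batch) for _ in range(20)]
print(losses[0], losses[-1], model.w)
```

The factory pattern is the point here: the model and learning rate are captured once, and the engine only ever sees a function that takes a batch.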

create_supervised_evaluator

def create_supervised_evaluator(model, metrics={}, cuda=False):
    """
    Factory function for creating an evaluator for supervised models

    Args:
        model (torch.nn.Module): the model to train
        metrics (dict of str: Metric): a map of metric names to Metrics
        cuda (bool, optional): whether or not to transfer batch to GPU (default: False)

    Returns:
        Evaluator: an evaluator instance with supervised inference function
    """
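The `metrics` argument maps names to metric objects that accumulate results batch by batch. A rough sketch of that idea in plain Python follows; the `Accuracy` class and the `evaluate` helper are hypothetical stand-ins written for this post, not ignite's `CategoricalAccuracy` or its evaluator.

```python
class Accuracy:
    """Hypothetical running-accuracy metric that accumulates over batches."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.correct = 0
        self.total = 0

    def update(self, output):
        y_pred, y = output
        self.correct += sum(int(p == t) for p, t in zip(y_pred, y))
        self.total += len(y)

    def compute(self):
        if self.total == 0:
            raise ValueError("compute() called before any update()")
        return self.correct / self.total


def evaluate(inference_fn, batches, metrics):
    """Run inference on every batch and return a dict {metric name: value}."""
    for metric in metrics.values():
        metric.reset()
    for batch in batches:
        output = inference_fn(batch)          # (predictions, targets)
        for metric in metrics.values():
            metric.update(output)
    return {name: metric.compute() for name, metric in metrics.items()}


def inference(batch):
    xs, ys = batch
    # Toy "model": predict class 1 for positive inputs, class 0 otherwise.
    return [int(x > 0) for x in xs], ys


batches = [([-1.0, 2.0], [0, 1]), ([3.0, -4.0], [0, 0])]
results = evaluate(inference, batches, {"accuracy": Accuracy()})
print(results)
```

Accumulating counts and dividing once at the end gives the accuracy over the whole dataset, which differs from averaging per-batch accuracies when batch sizes are unequal.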

Trainer

# Inherits from Engine
def __init__(self, process_function):
    pass

"""
The signature of process_function is func(batch) -> anything

def func(batch):   # batch is saved in state.batch
    1. Process the batch
    2. Forward computation
    3. Compute the loss
    4. Compute the gradients
    5. Update the parameters
    6. Return the loss or something else   # the returned value is saved in state.output
"""

"""
Register a function for an event; when the event occurs, the function is called.
The function receives the trainer as its argument, def func(trainer);
the state is available as trainer.state.
"""
@trainer.on(...)
def some_func(trainer):
    pass

trainer.run()   # train the model
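To see how registration and `run()` fit together, here is a heavily simplified sketch. `MiniEngine` is a hypothetical class written for this post and is not ignite's Engine: it uses plain strings as event names, keeps `iteration` and `output` directly on the engine instead of on a separate state object, and leaves out exception handling.

```python
from collections import defaultdict


class MiniEngine:
    """Hypothetical, minimal engine: a handler registry plus a run loop."""

    def __init__(self, process_function):
        self.process_function = process_function
        self.handlers = defaultdict(list)
        self.iteration = 0
        self.output = None

    def on(self, event_name):
        """Decorator that registers the decorated function for event_name."""
        def decorator(handler):
            self.handlers[event_name].append(handler)
            return handler
        return decorator

    def fire(self, event_name):
        # Call every handler registered for this event, passing the engine itself.
        for handler in self.handlers[event_name]:
            handler(self)

    def run(self, data, max_epochs=1):
        self.fire("started")
        for _ in range(max_epochs):
            for batch in data:
                self.iteration += 1
                self.output = self.process_function(batch)
                self.fire("iteration_completed")
            self.fire("epoch_completed")
        self.fire("completed")
        return self


# The process function here just sums the batch; a real one would train a model.
trainer = MiniEngine(process_function=lambda batch: sum(batch))
log = []


@trainer.on("iteration_completed")
def record_output(engine):
    log.append((engine.iteration, engine.output))


@trainer.on("epoch_completed")
def mark_epoch(engine):
    log.append("epoch done")


trainer.run([[1, 2], [3, 4]], max_epochs=2)
print(log)
```

The decorator returns the handler unchanged, so the same function can still be called directly or registered for several events.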

Evaluator

# Inherits from Engine
def __init__(self, process_function):
    pass

"""
The signature of process_function is func(batch) -> anything

def func(batch):   # batch is saved in state.batch
    1. Process the batch
    2. Forward computation
    3. Return something   # the returned value is saved in state.output
                          # and is used to compute the metrics
"""

# Register functions for some Evaluator events
@evaluator.on(...)
def func(evaluator):
    pass

evaluator.run()    # performs the computation and returns the state
state.metrics      # the metric results on the validation set are saved here
