Write a BP Neural Network in Java (II)
With Net and Propagation in place, we can now train the network. The Trainer's job is to decide how to split a large set of samples into mini-batches and how to merge the mini-batch results back into one complete result (batch versus incremental training); when to invoke the Learner so it can learn from the training result and improve the network's weights and state; and when to decide that training is finished.
So what do the Trainer and the Learner look like, and what do they do?
public interface Trainer {
    public void train(Net net, DataProvider provider);
}

public interface Learner {
    public void learn(Net net, TrainResult trainResult);
}
In other words, a Trainer trains the specified network on the given data, and a Learner adjusts the weights of the specified network based on a given training result.
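To make the contract concrete before looking at the real implementations, here is a minimal hypothetical Learner that performs plain gradient descent with a fixed rate. This class is only an illustration, not part of the original code; it relies on the Net, TrainResult, and BackwardResult classes used throughout this series.

import org.jblas.DoubleMatrix;

// Illustrative only: the simplest possible Learner, plain gradient descent.
public class SimpleGradientLearner implements Learner {
    double rate = 0.1; // fixed learning rate (assumed value)

    @Override
    public void learn(Net net, TrainResult trainResult) {
        BackwardResult result = trainResult.backwardResult;
        for (int j = 0; j < net.getLayersNum(); j++) {
            // W := W - rate * dW and b := b - rate * db for every layer
            net.getWeights().get(j).subi(result.gradients.get(j).mul(rate));
            net.getBs().get(j).subi(result.biasGradients.get(j).mul(rate));
        }
    }
}

The MomentAdaptLearner shown later follows the same contract, but adds momentum and an adaptive learning rate.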
A simple implementation of these two interfaces is given below.
Trainer
The CommonTrainer class implements simple mini-batch training that stops after a given number of epochs. The code is shown below.
import java.util.ArrayList;
import java.util.List;

import org.jblas.DoubleMatrix;

public class CommonTrainer implements Trainer {
    int ecophs;
    Learner learner;
    List<Double> costs = new ArrayList<>();
    List<Double> accuracys = new ArrayList<>();
    int batchSize = 1;

    public CommonTrainer(int ecophs, Learner learner) {
        super();
        this.ecophs = ecophs;
        this.learner = learner == null ? new MomentAdaptLearner() : learner;
    }

    public CommonTrainer(int ecophs, Learner learner, int batchSize) {
        this(ecophs, learner);
        this.batchSize = batchSize;
    }

    public void trainOne(final Net net, DataProvider provider) {
        final Propagation propagation = new Propagation(net);
        DoubleMatrix input = provider.getInput();
        DoubleMatrix target = provider.getTarget();
        final int allLen = target.columns;
        final int[] nodesNum = net.getNodesNum();
        final int layersNum = net.getLayersNum();

        List<DoubleMatrix> inputBatches = this.getBatches(input);
        final List<DoubleMatrix> targetBatches = this.getBatches(target);
        final List<Integer> batchLen = MatrixUtil.getEndPosition(targetBatches);
        final BackwardResult backwardResult = new BackwardResult(net, allLen);

        // Train the mini-batches in parallel, merging each partial result
        // into the shared backwardResult.
        Parallel.For(inputBatches, new Parallel.Operation<DoubleMatrix>() {
            @Override
            public void perform(int index, DoubleMatrix subInput) {
                ForwardResult subResult = propagation.forward(subInput);
                DoubleMatrix subTarget = targetBatches.get(index);
                BackwardResult backResult = propagation.backward(subTarget, subResult);

                DoubleMatrix cost = backwardResult.cost;
                DoubleMatrix accuracy = backwardResult.accuracy;
                DoubleMatrix inputDeltas = backwardResult.getInputDelta();

                // Copy this batch's per-sample cost, accuracy and input deltas
                // into the columns it covers.
                int start = index == 0 ? 0 : batchLen.get(index - 1);
                int end = batchLen.get(index) - 1;
                int[] cIndexs = ArraysHelper.makeArray(start, end);
                cost.put(cIndexs, backResult.cost);
                if (accuracy != null) {
                    accuracy.put(cIndexs, backResult.accuracy);
                }
                inputDeltas.put(ArraysHelper.makeArray(0, nodesNum[0] - 1),
                        cIndexs, backResult.getInputDelta());

                // Accumulate the gradients, weighted by the batch length, so
                // that dividing by allLen below yields the overall mean.
                for (int i = 0; i < layersNum; i++) {
                    DoubleMatrix gradients = backwardResult.gradients.get(i);
                    DoubleMatrix biasGradients = backwardResult.biasGradients.get(i);
                    DoubleMatrix subGradients = backResult.gradients.get(i)
                            .muli(backResult.cost.columns);
                    DoubleMatrix subBiasGradients = backResult.biasGradients.get(i)
                            .muli(backResult.cost.columns);
                    gradients.addi(subGradients);
                    biasGradients.addi(subBiasGradients);
                }
            }
        });

        // Take the mean of the accumulated gradients.
        for (DoubleMatrix gradient : backwardResult.gradients) {
            gradient.divi(allLen);
        }
        for (DoubleMatrix gradient : backwardResult.biasGradients) {
            gradient.divi(allLen);
        }

        TrainResult trainResult = new TrainResult(null, backwardResult);
        learner.learn(net, trainResult);

        Double cost = backwardResult.getMeanCost();
        Double accuracy = backwardResult.getMeanAccuracy();
        if (cost != null)
            costs.add(cost);
        if (accuracy != null)
            accuracys.add(accuracy);
        System.out.println(cost);
        System.out.println(accuracy);
    }

    @Override
    public void train(Net net, DataProvider provider) {
        for (int i = 0; i < this.ecophs; i++) {
            this.trainOne(net, provider);
        }
    }
}
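Two helpers used above, getBatches and MatrixUtil.getEndPosition, are not included in the excerpt. Assuming samples are stored one per column (as the rest of the code implies), plausible reconstructions look like this; the method bodies below are sketches, not the original implementations.

// Sketch of CommonTrainer.getBatches: split a sample matrix (one sample per
// column) into consecutive blocks of at most batchSize columns.
private List<DoubleMatrix> getBatches(DoubleMatrix data) {
    List<DoubleMatrix> batches = new ArrayList<>();
    for (int start = 0; start < data.columns; start += batchSize) {
        int end = Math.min(start + batchSize, data.columns);
        batches.add(data.getRange(0, data.rows, start, end)); // jblas end indices are exclusive
    }
    return batches;
}

// Sketch of MatrixUtil.getEndPosition: cumulative end positions of the
// batches, e.g. batch sizes [3, 3, 2] yield [3, 6, 8].
public static List<Integer> getEndPosition(List<DoubleMatrix> batches) {
    List<Integer> ends = new ArrayList<>();
    int sum = 0;
    for (DoubleMatrix batch : batches) {
        sum += batch.columns;
        ends.add(sum);
    }
    return ends;
}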
Learner
The Learner implements the actual weight-update algorithm: once the gradients have been computed, it is responsible for adjusting the network's weights. The choice of update algorithm directly affects how fast the network converges. The implementation in this article uses a simple momentum algorithm with an adaptive learning rate. Its update formulas are as follows:
$$W(t+1)=W(t)+\Delta W(t)$$

$$\Delta W(t)=rate(t)\,(1-moment(t))\,G(t+1)+moment(t)\,\Delta W(t-1)$$

$$rate(t+1)=\begin{cases} rate(t)\times 1.05 & \mbox{if } cost(t)<cost(t-1)\\ rate(t)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 0.01 & \mbox{else} \end{cases}$$
$$moment(t+1)=\begin{cases} 0.9 & \mbox{if } cost(t)<cost(t-1)\\ moment(t)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 0.1 & \mbox{else} \end{cases}$$
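As a worked example, suppose $rate(t)=0.01$, $moment(t)=0.9$, and $cost(t-1)=0.50$. If the new cost is $0.45$ (an improvement), the rate grows to $0.0105$ and the moment stays at $0.9$. If the new cost is $0.51$ (worse, but below $0.50\times 1.04=0.52$), the rate shrinks to $0.007$ and the moment decays to $0.63$. If the new cost is $0.60$ (clearly worse), both are reset: the rate to $0.01$ and the moment to $0.1$.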
The sample code is as follows:
import org.jblas.DoubleMatrix;

public class MomentAdaptLearner implements Learner {
    Net net;
    double moment = 0.9;
    double lmd = 1.05;
    double preCost = 0;
    double eta = 0.01;
    double currentEta = eta;
    double currentMoment = moment;
    TrainResult preTrainResult;

    // No-argument constructor, used by CommonTrainer as the default learner;
    // it keeps the field defaults above (added here so the code compiles,
    // not shown in the original excerpt).
    public MomentAdaptLearner() {
        this(0.9, 0.01);
    }

    public MomentAdaptLearner(double moment, double eta) {
        super();
        this.moment = moment;
        this.eta = eta;
        this.currentEta = eta;
        this.currentMoment = moment;
    }

    @Override
    public void learn(Net net, TrainResult trainResult) {
        if (this.net == null)
            init(net);

        BackwardResult backwardResult = trainResult.backwardResult;
        BackwardResult preBackwardResult = preTrainResult.backwardResult;
        double cost = backwardResult.getMeanCost();
        this.modifyParameter(cost);
        System.out.println("current eta: " + this.currentEta);
        System.out.println("current moment: " + this.currentMoment);

        for (int j = 0; j < net.getLayersNum(); j++) {
            // Blend the new gradient with the previous step and update the weights.
            DoubleMatrix weight = net.getWeights().get(j);
            DoubleMatrix gradient = backwardResult.gradients.get(j);
            gradient = gradient.muli(currentEta * (1 - this.currentMoment))
                    .addi(preBackwardResult.gradients.get(j).muli(this.currentMoment));
            preBackwardResult.gradients.set(j, gradient);
            weight.subi(gradient);

            // Same update for the bias terms.
            DoubleMatrix b = net.getBs().get(j);
            DoubleMatrix bGradient = backwardResult.biasGradients.get(j);
            bGradient = bGradient.muli(currentEta * (1 - this.currentMoment))
                    .addi(preBackwardResult.biasGradients.get(j).muli(this.currentMoment));
            preBackwardResult.biasGradients.set(j, bGradient);
            b.subi(bGradient);
        }
    }

    public void modifyParameter(double cost) {
        if (cost < this.preCost) {
            // The cost dropped: speed up and restore the full moment.
            this.currentEta *= 1.05;
            this.currentMoment = moment;
        } else if (cost < 1.04 * this.preCost) {
            // Slightly worse: slow down and decay the moment.
            this.currentEta *= 0.7;
            this.currentMoment *= 0.7;
        } else {
            // Clearly worse: reset both parameters.
            this.currentEta = eta;
            this.currentMoment = 1 - moment;
        }
        this.preCost = cost;
    }

    public void init(Net net) {
        this.net = net;
        BackwardResult bResult = new BackwardResult();
        for (DoubleMatrix weight : net.getWeights()) {
            bResult.gradients.add(DoubleMatrix.zeros(weight.rows, weight.columns));
        }
        for (DoubleMatrix b : net.getBs()) {
            bResult.biasGradients.add(DoubleMatrix.zeros(b.rows, b.columns));
        }
        preTrainResult = new TrainResult(null, bResult);
    }
}
At this point, a simple neural network has been implemented end to end, from construction through training.
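To see how the pieces fit together, here is a minimal usage sketch. The Net and DataProvider arguments are assumed to come from Part I of this series; everything else uses the classes defined above, and the epoch count and batch size are arbitrary example values.

// Hypothetical wiring of the classes above; Net construction and the
// DataProvider implementation are assumed to exist elsewhere.
public static void trainNetwork(Net net, DataProvider provider) {
    Learner learner = new MomentAdaptLearner(0.9, 0.01); // moment = 0.9, initial eta = 0.01
    Trainer trainer = new CommonTrainer(500, learner, 100); // 500 epochs, mini-batches of 100
    trainer.train(net, provider);
}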