GANs are currently a very popular research direction in machine learning. Work on them falls into two broad categories: applying GANs to interesting problems, and trying to improve the stability of GAN training.
In fact, stability is very important in GAN training. The original GAN model has several training problems, e.g., mode collapse: the generator collapses onto a very narrow distribution that covers only a single mode of the data distribution. Mode collapse means the generator can produce only very similar samples (for example, a single digit from MNIST), i.e., the generated samples lack diversity. This clearly defeats the original purpose of GANs.
Another problem with GANs is the lack of a good indicator or metric for judging whether the model has converged. The generator and discriminator losses tell us little about this. Of course, we can monitor the training process by looking at the data produced by the generator, but that is a tedious manual process. What we need is an interpretable indicator that tells us how well training is going.
Wasserstein GAN
Wasserstein GAN (WGAN) is a new GAN algorithm that alleviates both problems to some degree. For the intuition and the theoretical background behind WGAN, see the paper and related material.
The pseudocode for the entire algorithm is as follows:
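Roughly, training proceeds as in the following Python-style sketch of Algorithm 1 from the WGAN paper. This is pseudocode, not runnable as written: f_w stands for the critic (discriminator) with weights w, g_theta for the generator with parameters theta, and sample_data, sample_prior, grad_w, grad_theta, RMSProp, and clip are placeholder names. The hyperparameter defaults (α = 0.00005, c = 0.01, n_critic = 5, m = 64) are the ones reported in the paper.

```python
# Python-style pseudocode for WGAN training (critic f_w, generator g_theta)
alpha, c, n_critic, m = 0.00005, 0.01, 5, 64

while not converged(theta):
    for t in range(n_critic):
        x = sample_data(m)                  # minibatch of real samples
        z = sample_prior(m)                 # minibatch of noise vectors
        # gradient ascent on E[f_w(x)] - E[f_w(g_theta(z))]
        g_w = grad_w(mean(f_w(x)) - mean(f_w(g_theta(z))))
        w = w + alpha * RMSProp(w, g_w)
        w = clip(w, -c, c)                  # weight clipping
    z = sample_prior(m)
    # gradient descent on -E[f_w(g_theta(z))]
    g_theta = grad_theta(-mean(f_w(g_theta(z))))
    theta = theta - alpha * RMSProp(theta, g_theta)
```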
We can see that the algorithm is very similar to the original GAN algorithm. However, for WGAN, we need to note the following points from the pseudocode above:
- There is no log in the loss function. The output of the discriminator D(x) is no longer a probability but just a scalar score, which also means there is no sigmoid activation at the output
- Clip the weights w of the discriminator D(x) after each update
- Train the discriminator more times than the generator
- Replace the Adam optimizer with the RMSProp optimizer
- Use a very low learning rate, α = 0.00005
WGAN TensorFlow Implementation
The basic GAN implementation was described in the previous article. We only need to modify the traditional GAN slightly. First, let's update our discriminator D(x):
"" "" ""def discriminator (x): = Tf.nn.relu (Tf.matmul (x, D_W1) + d_b1 ) = Tf.matmul (d_h1, d_w2) + d_b2 return tf.nn.sigmoid (out)"" " "" "def discriminator (x): = Tf.nn.relu (Tf.matmul (x, D_W1) + d_b1) = Tf.matmul (D_H1, d_w2) + d_b2 return out
Next, modify the loss function to remove the log:
"" "" "=-tf.reduce_mean (Tf.log (d_real) + tf.log (1. -=-Tf.reduce_mean (Tf.log (d_fake))"" "" "" = Tf.reduce_mean (d_real)-=-tf.reduce_mean (d_fake)
After each gradient update, clip the weights of the discriminator D(x):
```python
# theta_D is the list of D's parameters; the clip value 0.01 is the WGAN paper's default
clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in theta_D]
```
Then, you just need to train the discriminator D(x) more times than the generator:
```python
D_solver = (tf.train.RMSPropOptimizer(learning_rate=5e-5)
            .minimize(-D_loss, var_list=theta_D))
G_solver = (tf.train.RMSPropOptimizer(learning_rate=5e-5)
            .minimize(G_loss, var_list=theta_G))

for it in range(1000000):
    # train the critic several times (5 in the WGAN paper) per generator update
    for _ in range(5):
        X_mb, _ = mnist.train.next_batch(mb_size)

        _, D_loss_curr, _ = sess.run(
            [D_solver, D_loss, clip_D],
            feed_dict={X: X_mb, z: sample_z(mb_size, z_dim)})

    _, G_loss_curr = sess.run(
        [G_solver, G_loss],
        feed_dict={z: sample_z(mb_size, z_dim)})
```
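This connects back to the earlier point about a meaningful training metric: with the loss defined above, D_loss = E[D(x_real)] - E[D(x_fake)] is exactly the quantity the critic maximizes, so its value can be read as a rough estimate of the Wasserstein distance, and the WGAN paper reports that this estimate correlates well with sample quality. A minimal sketch of logging it inside the training loop (the interval of 1000 iterations is an arbitrary choice):

```python
    # inside the outer training loop above
    if it % 1000 == 0:
        print('Iter: {}; D_loss (Wasserstein estimate): {:.4}; G_loss: {:.4}'
              .format(it, D_loss_curr, G_loss_curr))
```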
Wasserstein Generative Adversarial Nets (WGAN)