GAN Theories
Generative Adversarial Network [2014 Goodfellow]
基本组件
Discriminator(判别器): 用来判断图片是真实的还是生成的
Generator(生成器): 生成使判别器无法判别真假的图片
Formula
Objective
我们的目标是得到一个理想的生成器(或者判别器)
这里提醒一下,我们理想的生成器是通过学会将输入数据的数据分布改变成真实数据的数据分布来实现的。
自己写的马上要出去吃饭了,懒得翻译成中文了,姑且打个引用标记
For the data set, after converting the data into vectors, the data of people has its own distribution, and the data of cars also has its own distribution. The data distribution of cars is definitely not the same as the data distribution of people. For two-dimensional data with two Gaussian distributions, the difference between the two data distributions is equivalent to the difference in expectation and variance. Of course, the actual data is much more complex than the above example, we can only understand their distribution abstractly
For traditional network, we need paired data to train the model (sample to sample). For example, if we want the model to de the colorization. we need a gray picture of chicken and corresponding color picture.
But if we let the model learn the distribution of color picture, we can use other color picture like car, horse and human to let the model do the colorize job for chicken picture.
But how can we prevent the model turning the chicken picture into horse? Remember, we can stack another loss function to discriminate the picture. In this way, we can colorize the picture and leave the object's appearance unchanged in space
代码实现
Gradient
Problem
由于在训练刚开始的时候,generator的能力很差,导致discriminator非常容易判别,这样Generator的梯度就会非常小,达不到训练的效果
solutions
更改loss函数
![image-20220522180623346](01 GAN 基础.assets/image-20220522180623346.png)
使用负标签(见实例代码)
训练代码示例(pytorch)
criterion = nn.BCELoss()
valid = Tensor(imgs.size(0), 1).fill_(1.0).detach()
fake = Tensor(imgs.size(0), 1).fill_(0.0).detach()
# detach means "requires_grad = False"
# Train Generator
optimizer_G.zero_grad()
z = Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))
gen_imgs = generator(z)
g_loss = criterion(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()
# Train Discriminator
optimizer_D.zero_grad()
real_loss = criterion(discriminator(real_imgs), valid)
fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
总结
最初GAN的问题
- 难以训练
- 容易得到相同的图片(以MNIST数据集为例非常容易生成1和7)
- 训练过程很不稳定,会出现效果倒退的情况
GAN相关代码已上传至GitHub RottenTangerine/ML/GAN at master · (github.com)