DCGAN only produces noise

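Problem description

I am new to GANs and I am having trouble training a DCGAN on MNIST. When I used a GAN with linear layers, everything worked well and the generator produced very good images. But when I switched to a convolutional GAN, the generator only produces noise. Does anyone know how to fix this?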

[Image: generated images]

[Image: discriminator loss and generator loss]

Here are my networks and my training loop.

Networks:

Discriminator

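import torch.nn as nn

class discriminator(nn.Module):
    def __init__(self,ch,F):
        super(discriminator,self).__init__()

        # Conv2d followed by optional BatchNorm2d and the block's activation
        def block(in_ch,out_ch,k,s,p,final=False,bn=True):
            block = []
            block.append(nn.Conv2d(in_ch,out_ch,kernel_size=k,stride=s,padding=p))
            if not final and bn:
                block.append(nn.BatchNorm2d(out_ch))
                block.append(nn.LeakyReLU(0.2))
            elif not bn and not final:
                block.append(nn.LeakyReLU(0.2))
            elif final and not bn:
                block.append(nn.Sigmoid())
            return block

        # NOTE: some arguments were dropped from the source; they are filled in
        # with k=4, s=2, p=1 (an assumption based on the first block and the
        # standard DCGAN layout).
        self.D = nn.Sequential(
            *block(ch,F,k=4,s=2,p=1,bn=False),
            *block(F,F*2,k=4,s=2,p=1),
            *block(F*2,F*4,k=4,s=2,p=1),
            *block(F*4,F*8,k=4,s=2,p=1),
            *block(F*8,1,k=4,s=2,p=0,final=True,bn=False)
        )

    def forward(self,x): return self.D(x)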
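
As a quick sanity check on the discriminator's layer arithmetic, the plain-Python sketch below traces the spatial size after each convolution using the standard Conv2d output-size formula. It assumes every block uses kernel size 4 and stride 2 like the first one, with padding 1 except for the final block (padding 0); `conv_out`, `trace_sizes`, and `DISC_LAYERS` are hypothetical names, not part of the question's code. Under these assumptions a 64x64 input reduces to one value per image, while a native 28x28 MNIST image becomes too small before the last layer, so the input would presumably need to be resized to 64x64 first.

```python
def conv_out(size, k, s, p):
    """Output size of nn.Conv2d along one spatial dimension (dilation 1)."""
    return (size + 2 * p - k) // s + 1

# (kernel, stride, padding) per block; assumed values, see the note above
DISC_LAYERS = [(4, 2, 1)] * 4 + [(4, 2, 0)]

def trace_sizes(size, layers=DISC_LAYERS):
    """Return the spatial size before the first layer and after each layer."""
    sizes = [size]
    for k, s, p in layers:
        if size + 2 * p < k:
            raise ValueError(f"input of size {size} is smaller than kernel {k}")
        size = conv_out(size, k, s, p)
        sizes.append(size)
    return sizes

print(trace_sizes(64))  # [64, 32, 16, 8, 4, 1]

try:
    trace_sizes(28)
except ValueError as err:
    print("28x28 input:", err)
```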

Generator

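class Generator(nn.Module):
    def __init__(self,ch_noise,ch_img,features_g):
        super(Generator,self).__init__()

        # ConvTranspose2d followed by BatchNorm2d+ReLU, or Tanh on the final block
        def block(in_ch,out_ch,k,s,p,final=False):
            block = []
            block.append(nn.ConvTranspose2d(in_ch,out_ch,kernel_size=k,stride=s,padding=p))
            if not final:
                block.append(nn.BatchNorm2d(out_ch))
                block.append(nn.ReLU())
            if final:
                block.append(nn.Tanh())
            return block

        # NOTE: some arguments were dropped from the source; k=4, s=2, p=1 and
        # ch_img as the final output channels are assumptions (standard DCGAN layout).
        self.G = nn.Sequential(
            *block(ch_noise,features_g*16,k=4,s=1,p=0),
            *block(features_g*16,features_g*8,k=4,s=2,p=1),
            *block(features_g*8,features_g*4,k=4,s=2,p=1),
            *block(features_g*4,features_g*2,k=4,s=2,p=1),
            *block(features_g*2,ch_img,k=4,s=2,p=1,final=True)
        )

    def forward(self,z): return self.G(z)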
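
The same kind of check can be done for the generator. The sketch below uses the ConvTranspose2d output-size formula (with no output padding and dilation 1) and assumes kernel size 4 throughout, stride 1 and padding 0 for the first block, and stride 2 and padding 1 for the rest; `convT_out`, `trace_gen_sizes`, and `GEN_LAYERS` are hypothetical names. Since the chain starts from a 1x1 spatial input, the noise tensor would be expected to have shape (N, ch_noise, 1, 1) rather than (N, ch_noise).

```python
def convT_out(size, k, s, p):
    """Output size of nn.ConvTranspose2d along one spatial dimension
    (output_padding 0, dilation 1)."""
    return (size - 1) * s - 2 * p + k

# (kernel, stride, padding) per block; assumed values, see the note above
GEN_LAYERS = [(4, 1, 0)] + [(4, 2, 1)] * 4

def trace_gen_sizes(size=1, layers=GEN_LAYERS):
    """Return the spatial size before the first layer and after each layer."""
    sizes = [size]
    for k, s, p in layers:
        size = convT_out(size, k, s, p)
        sizes.append(size)
    return sizes

print(trace_gen_sizes())  # [1, 4, 8, 16, 32, 64]
```

Under these assumptions the generator emits 64x64 images, which matches a discriminator built for 64x64 inputs.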

Training loop:

loss_G,loss_D = [],[]
for i in range(epochs):
    D.train()
    G.train()
    st = time.time()
    for idx,(img,_ ) in enumerate(mnist):
        img = img.to(device)
        ## discriminator ##
        D.zero_grad(set_to_none=True)
        lable = torch.ones(bs,device=device)*0.9
        pred = D(img).reshape(-1)
        loss_d_real = criterion(pred,lable)

        z = torch.randn(img.shape[0],ch_z,device=device)
        fake_img = G(z)
        lable = torch.ones(bs,device=device)*0.1
        pred = D(fake_img.detach()).reshape(-1)
        loss_d_fake = criterion(pred,lable)
        D_loss = loss_d_real + loss_d_fake
        D_loss.backward()
        optim_d.step()
        ## Generator ##
        G.zero_grad(True)
        lable = torch.randn(bs,device=device)
        pred = D(fake_img).reshape(-1)
        G_loss = criterion(pred,lable)
        G_loss.backward()
        optim_g.step()
        ## printing on terminal
        if idx % 100 == 0:
            print(f'\nBatches done : {idx}/{len(mnist)}')
            print(f'Loss_D : {D_loss.item():.4f}\tLoss_G : {G_loss.item():.4f}')
    et = time.time()
    print(f'\nEpoch : {i+1}\n{time_cal(st,et)}')
    G.eval()
    with torch.no_grad():
        fake_image = G(fixed_noise)
        save_image(fake_image[:25],fp=f'{path_to_img}/{i+1}_fake.png',nrow=5,normalize=True)

    loss_G.append(G_loss.item())
    loss_D.append(D_loss.item())
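
One detail in the generator step of the loop above may be worth checking: the target is drawn with `torch.randn`, whereas the usual GAN generator update uses the "real" label (for example a tensor of ones) so that the generator is rewarded when the discriminator scores its samples as real. The plain-Python sketch below illustrates why this could matter, assuming `criterion` is binary cross-entropy (its definition is not shown in the question); `bce` and `bce_grad` are hypothetical helpers written for this illustration.

```python
import math
import random

def bce(p, t):
    """Binary cross-entropy for one prediction p in (0, 1) and target t."""
    return -(t * math.log(p) + (1 - t) * math.log(1 - p))

def bce_grad(p, t):
    """d(bce)/dp; a negative value means raising p lowers the loss."""
    return -t / p + (1 - t) / (1 - p)

p = 0.3  # example discriminator output on a generated image

# Target 1.0 ("real"): the gradient is negative, so minimising the loss
# pushes the discriminator's score on the fake image upward.
print("grad with target 1.0:", bce_grad(p, 1.0))

# Targets drawn from a standard normal: the gradient is negative only when
# the sampled target exceeds p, so the update direction varies per sample.
random.seed(0)
pushes_up = [bce_grad(p, random.gauss(0, 1)) < 0 for _ in range(1000)]
print("fraction of samples pushing the score up:", sum(pushes_up) / len(pushes_up))
```

If this is indeed the cause, replacing the random target with ones of the same shape as `pred` in the generator step would be the usual fix; this is a suggestion under the stated assumption, not a confirmed diagnosis.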

Here's the link to my Colab notebook in which I have been working.
