Создание новых изображений с помощью PyTorch

Я изучаю GAN. Я закончил один курс, который дал мне пример программы, которая генерирует изображения на основе введенных примеров.

Пример можно найти здесь:

https://github.com/davidsonmizael/gan

Поэтому я решил использовать это для создания новых изображений на основе набора данных фронтальных фотографий лиц, но я не добился успеха. В отличие от приведенного выше примера, код генерирует только шум, а на входе - фактические изображения.

На самом деле, я не имею ни малейшего представления о том, что я должен изменить, чтобы код указывал в правильном направлении и учился на изображениях. Я не изменяю ни одного значения кода, представленного в примере, но он не работает.

Если кто-нибудь может помочь мне понять это и указать мне в правильном направлении, это будет очень полезно. Спасибо заранее.

Мой Дискриминатор:

class D(nn.Module):

    def __init__(self):
        super(D, self).__init__()
        self.main = nn.Sequential(
                nn.Conv2d(3, 64, 4, 2, 1, bias = False),
                nn.LeakyReLU(0.2, inplace = True),
                nn.Conv2d(64, 128, 4, 2, 1, bias = False),
                nn.BatchNorm2d(128),
                nn.LeakyReLU(0.2, inplace = True),
                nn.Conv2d(128, 256, 4, 2, 1, bias = False),
                nn.BatchNorm2d(256),
                nn.LeakyReLU(0.2, inplace = True),
                nn.Conv2d(256, 512, 4, 2, 1, bias = False),
                nn.BatchNorm2d(512),
                nn.LeakyReLU(0.2, inplace = True),
                nn.Conv2d(512, 1, 4, 1, 0, bias = False),
                nn.Sigmoid()
                )

    def forward(self, input):
        return self.main(input).view(-1)

Мой генератор:

class G(nn.Module):

    def __init__(self):
        super(G, self).__init__()
        self.main = nn.Sequential(
                nn.ConvTranspose2d(100, 512, 4, 1, 0, bias = False),
                nn.BatchNorm2d(512),
                nn.ReLU(True),
                nn.ConvTranspose2d(512, 256, 4, 2, 1, bias = False),
                nn.BatchNorm2d(256),
                nn.ReLU(True),
                nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False),
                nn.BatchNorm2d(128),
                nn.ReLU(True),
                nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False),
                nn.BatchNorm2d(64),
                nn.ReLU(True),
                nn.ConvTranspose2d(64, 3, 4, 2, 1, bias = False),
                nn.Tanh()
                )

    def forward(self, input):
        return self.main(input)

Моя функция для запуска весов:

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

Полный код можно увидеть здесь:

https://github.com/davidsonmizael/criminal-gan

Шум, генерируемый в эпоху номер 25: Шум, созданный в эпоху номер 25

Ввод с реальными изображениями: Ввод с реальными изображениями.

Ответы

Ответ 1

Код из вашего примера (https://github.com/davidsonmizael/gan) дал мне тот же шум, что и вы. Потеря генератора слишком быстро снизилась.

Было несколько ошибок, я даже не уверен, что - но я думаю, что легко разобраться в различиях. Для сравнения также рассмотрим этот учебник: GAN в 50 строках PyTorch

.... same as your code
print("# Starting generator and descriminator...")
netG = G()
netG.apply(weights_init)

netD = D()
netD.apply(weights_init)

if torch.cuda.is_available():
    netG.cuda()
    netD.cuda()

#training the DCGANs
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999))

epochs = 25

timeElapsed = []
for epoch in range(epochs):
    print("# Starting epoch [%d/%d]..." % (epoch, epochs))
    for i, data in enumerate(dataloader, 0):
        start = time.time()
        time.clock()  

        #updates the weights of the discriminator nn
        netD.zero_grad()

        #trains the discriminator with a real image
        real, _ = data

        if torch.cuda.is_available():
            inputs = Variable(real.cuda()).cuda()
            target = Variable(torch.ones(inputs.size()[0]).cuda()).cuda()
        else:
            inputs = Variable(real)
            target = Variable(torch.ones(inputs.size()[0]))

        output = netD(inputs)
        errD_real = criterion(output, target)
        errD_real.backward() #retain_graph=True

        #trains the discriminator with a fake image
        if torch.cuda.is_available():
            D_noise = Variable(torch.randn(inputs.size()[0], 100, 1, 1).cuda()).cuda()
            target = Variable(torch.zeros(inputs.size()[0]).cuda()).cuda()
        else:
            D_noise = Variable(torch.randn(inputs.size()[0], 100, 1, 1))
            target = Variable(torch.zeros(inputs.size()[0]))
        D_fake = netG(D_noise).detach()
        D_fake_ouput = netD(D_fake)
        errD_fake = criterion(D_fake_ouput, target)
        errD_fake.backward()

        # NOT:backpropagating the total error
        # errD = errD_real + errD_fake

        optimizerD.step()

    #for i, data in enumerate(dataloader, 0):

        #updates the weights of the generator nn
        netG.zero_grad()

        if torch.cuda.is_available():
            G_noise = Variable(torch.randn(inputs.size()[0], 100, 1, 1).cuda()).cuda()
            target = Variable(torch.ones(inputs.size()[0]).cuda()).cuda()
        else:
            G_noise = Variable(torch.randn(inputs.size()[0], 100, 1, 1))
            target = Variable(torch.ones(inputs.size()[0]))

        fake = netG(G_noise)
        G_output = netD(fake)
        errG  = criterion(G_output, target)

        #backpropagating the error
        errG.backward()
        optimizerG.step()


        if i % 50 == 0:
            #prints the losses and save the real images and the generated images
            print("# Progress: ")
            print("[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f" % (epoch, epochs, i, len(dataloader), errD_real.data[0], errG.data[0]))

            #calculates the remaining time by taking the avg seconds that every loop
            #and multiplying by the loops that still need to run
            timeElapsed.append(time.time() - start)
            avg_time = (sum(timeElapsed) / float(len(timeElapsed)))
            all_dtl = (epoch * len(dataloader)) + i
            rem_dtl = (len(dataloader) - i) + ((epochs - epoch) * len(dataloader))
            remaining =  (all_dtl - rem_dtl) * avg_time
            print("# Estimated remaining time: %s" % (time.strftime("%H:%M:%S", time.gmtime(remaining))))

        if i % 100 == 0:
            vutils.save_image(real, "%s/real_samples.png" % "./results", normalize = True)
            vutils.save_image(fake.data, "%s/fake_samples_epoch_%03d.png" % ("./results", epoch), normalize = True)

print ("# Finished.")

Результат после 25 эпох (batchsize 256) на CIFAR-10: введите описание изображения здесь

Ответ 2

Обучение GAN происходит не очень быстро. Я предполагаю, что вы не используете предварительно подготовленную модель, но учитесь с нуля. В эпоху 25 вполне нормально не видеть каких-либо значимых паттернов в образцах. Я понимаю, что проект github показывает вам что-то прохладное после 25 эпох, но это также зависит от размера набора данных. CIFAR-10 (тот, который использовался на странице github) имеет 60000 изображений. 25 эпох означает, что сеть увидела все 25 раз.

Я не знаю, какой набор данных вы используете, но если он меньше, может потребоваться больше эпох, пока вы не увидите результаты, потому что сеть получает меньше изображений в целом. Если изображения в вашем наборе данных имеют более высокое разрешение, это может также занять больше времени.

Вы должны проверить снова, по крайней мере, несколько сотен, если не несколько тысяч эпох.


например. на фронтальном лицевом наборе фотографий после 25 эпох: введите описание изображения здесь

И через 50 эпох: введите описание изображения здесь