Ответ 1
Предлагаемое решение
Повторное использование кода из репозитория, который вы поделили, вот некоторые предлагаемые модификации для обучения классификатора вдоль вашего генератора и дискриминатора (их архитектуры и другие потери остаются нетронутыми):
from keras import backend as K
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
def lenet_classifier_model(nb_classes):
# Snipped by Fabien Tanc - https://www.kaggle.com/ftence/keras-cnn-inspired-by-lenet-5
# Replace with your favorite classifier...
model = Sequential()
model.add(Convolution2D(12, 5, 5, activation='relu', input_shape=in_shape, init='he_normal'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Convolution2D(25, 5, 5, activation='relu', init='he_normal'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(180, activation='relu', init='he_normal'))
model.add(Dropout(0.5))
model.add(Dense(100, activation='relu', init='he_normal'))
model.add(Dropout(0.5))
model.add(Dense(nb_classes, activation='softmax', init='he_normal'))
def generator_containing_discriminator_and_classifier(generator, discriminator, classifier):
inputs = Input((IN_CH, img_cols, img_rows))
x_generator = generator(inputs)
merged = merge([inputs, x_generator], mode='concat', concat_axis=1)
discriminator.trainable = False
x_discriminator = discriminator(merged)
classifier.trainable = False
x_classifier = classifier(x_generator)
model = Model(input=inputs, output=[x_generator, x_discriminator, x_classifier])
return model
def train(BATCH_SIZE):
(X_train, Y_train, LABEL_train) = get_data('train') # replace with your data here
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
Y_train = (Y_train.astype(np.float32) - 127.5) / 127.5
discriminator = discriminator_model()
generator = generator_model()
classifier = lenet_classifier_model(6)
generator.summary()
discriminator_and_classifier_on_generator = generator_containing_discriminator_and_classifier(
generator, discriminator, classifier)
d_optim = Adagrad(lr=0.005)
g_optim = Adagrad(lr=0.005)
generator.compile(loss='mse', optimizer="rmsprop")
discriminator_and_classifier_on_generator.compile(
loss=[generator_l1_loss, discriminator_on_generator_loss, "categorical_crossentropy"],
optimizer="rmsprop")
discriminator.trainable = True
discriminator.compile(loss=discriminator_loss, optimizer="rmsprop")
classifier.trainable = True
classifier.compile(loss="categorical_crossentropy", optimizer="rmsprop")
for epoch in range(100):
print("Epoch is", epoch)
print("Number of batches", int(X_train.shape[0] / BATCH_SIZE))
for index in range(int(X_train.shape[0] / BATCH_SIZE)):
image_batch = Y_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
label_batch = LABEL_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE] # replace with your data here
generated_images = generator.predict(X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE])
if index % 20 == 0:
image = combine_images(generated_images)
image = image * 127.5 + 127.5
image = np.swapaxes(image, 0, 2)
cv2.imwrite(str(epoch) + "_" + str(index) + ".png", image)
# Image.fromarray(image.astype(np.uint8)).save(str(epoch)+"_"+str(index)+".png")
# Training D:
real_pairs = np.concatenate((X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], image_batch),
axis=1)
fake_pairs = np.concatenate(
(X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], generated_images), axis=1)
X = np.concatenate((real_pairs, fake_pairs))
y = np.zeros((20, 1, 64, 64)) # [1] * BATCH_SIZE + [0] * BATCH_SIZE
d_loss = discriminator.train_on_batch(X, y)
print("batch %d d_loss : %f" % (index, d_loss))
discriminator.trainable = False
# Training C:
c_loss = classifier.train_on_batch(image_batch, label_batch)
print("batch %d c_loss : %f" % (index, c_loss))
classifier.trainable = False
# Train G:
g_loss = discriminator_and_classifier_on_generator.train_on_batch(
X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :],
[image_batch, np.ones((10, 1, 64, 64)), label_batch])
discriminator.trainable = True
classifier.trainable = True
print("batch %d g_loss : %f" % (index, g_loss[1]))
if index % 20 == 0:
generator.save_weights('generator', True)
discriminator.save_weights('discriminator', True)
Теоретические детали
Я считаю, что есть некоторые недоразумения относительно того, как работают условные ГАН и какова роль дискриминаторов в таких схемах.
Роль Дискриминатора
В игре min-max, которая является тренировкой GAN [4], дискриминатор D
играет против генератора G
(сети, на которой вы действительно заботитесь), так что под D
контролем G
становится лучше выводить реалистичные результаты.
Для этого D
обучается отличать реальные образцы от образцов из G
; в то время как G
обучается дурачить D
, генерируя реалистичные результаты/результаты, следующие за целевым распределением.
Примечание: в случае условных GAN, то есть GAN, отображающих входную выборку из одного домена
A
(например, реального изображения) в другой доменB
(например, эскиз),D
обычно питается парами образцов, сложенными вместе и должен различать "реальные "пары (входной образец изA
+ соответствующего целевого образца изB
) и" поддельные "пары (входной образец изA
+ соответствующего выхода изG
) [1, 2]
Обучение условного генератора против D
(в отличие от простого обучения G
только с потерей L1/L2, например, DAE) улучшает возможности выборки G
, заставляя его выводить четкие, реалистичные результаты, а не пытаться усреднить распределение.
Несмотря на то, что дискриминаторы могут иметь несколько подсетей для покрытия других задач (см. Следующие параграфы), D
должен поддерживать по крайней мере одну подсетеву/вывод для своей основной задачи: рассказывать реальные образцы из сгенерированных. Просить D
повторить дальнейшую семантическую информацию (например, классы) рядом может помешать этой основной цели.
Примечание: вывод
D
часто не является простым скаляром/булевым. Обычно существует дискриминатор (например, PatchGAN [1, 2]), возвращающий матрицу вероятностей, оценивая, насколько реалистичные исправления сделаны из его ввода.
Условные GAN
Традиционные ГАН обучаются неконтролируемым образом для создания реалистичных данных (например, изображений) из случайного шума в качестве входных данных. [4]
Как упоминалось ранее, условные GAN имеют дополнительные условия ввода. Вдоль/вместо вектора шума они берут для ввода образца из области A
и возвращают соответствующий образец из области B
A
может быть совершенно другой модальности, например B = sketch image
тогда как A = discrete label
; B = volumetric data
то время как A = RGB image
и т.д. [3]
Такие GAN также могут быть обусловлены многократными входами, например A = real image + discrete label
тогда как B = sketch image
. Известная работа по внедрению таких методов - InfoGAN [5]. В нем описывается, как согласовать GAN на нескольких непрерывных или дискретных входах (например, A = digit class + writing type
, B = handwritten digit image
), используя более продвинутый дискриминатор, который для второй задачи должен заставить G
максимизировать взаимную информацию между ее и его соответствующие выходы.
Максимизация взаимной информации для cGANs
У дискриминатора InfoGAN есть 2 главы/подсерии для выполнения двух задач [5]:
- Одна голова
D1
выполняет традиционную реальную/сгенерированную дискриминацию -G
должен минимизировать этот результат, т.е. Он должен обманутьD1
чтобы он не мог отличить данные, генерируемые реальной формой; - Другая головка
D2
(также называемаяQ
сетью) пытается регрессировать входную информациюA
-G
должна максимизировать этот результат, то есть она должна выводить данные, которые "показывают" запрошенную семантическую информацию (см. максимизация взаимной информации междуG
условными входами и его результаты).
Вы можете найти здесь реализацию Keras: https://github.com/eriklindernoren/Keras-GAN/tree/master/infogan.
В нескольких работах используются аналогичные схемы для улучшения контроля над тем, что генерирует GAN, используя предоставленные метки и максимизируя взаимную информацию между этими входами и выходами G
[6, 7]. Основная идея всегда одна и та же:
- Поезд
G
для генерации элементов доменаB
, учитывая некоторые входы домена (ов)A
; - Поезд
D
для распознавания "реальных"/"поддельных" результатов -G
должен свести к минимуму это; - Поезд
Q
(например, классификатор, может делиться слоями сD
), чтобы оценить исходные входыA
изB
выборок -G
должен максимизировать это).
Завершение
В вашем случае, кажется, у вас есть следующие данные обучения:
- реальные изображения
Ia
- соответствующие эскизные изображения
Ib
- соответствующие метки классов
c
И вы хотите обучить генератор G
так, чтобы при заданном изображении Ia
и его метке класса c
он выводил изображение эскиза Ib'
.
В общем, у вас много информации, и вы можете контролировать свое обучение как на условных изображениях, так и на условных ярлыках... Вдохновленный вышеупомянутыми методами [1, 2, 5, 6, 7], здесь возможный способ использования всей этой информации для обучения вашего условного G
:
G
: - Входы:
Ia
+c
- Выход:
Ib'
- Архитектура: до вас (например, U-Net, ResNet,...)
- Потери: потери L1/L2 между
Ib'
&Ib
,-D
убыток, потеряQ
D
: - Входы:
Ia
+Ib
(действительная пара),Ia
+Ib'
(поддельная пара) - Результат: скалярный/матричный скачок
- Архитектура: до вас (например, PatchGAN)
- Потеря: кросс-энтропия по оценке "подлости"
Q
: - Входы:
Ib
(реальный образец, для обученияQ
),Ib'
(поддельный образец, при обратном распространении черезG
) - Результат:
c'
(оценочный класс) - Архитектура: до вас (например, LeNet, ResNet, VGG,...)
- Потеря: кросс-энтропия между
c
иc'
- Поезд
D
на партию реальных парIa
+Ib
затем на партию поддельных парIa
+Ib'
; - Поезд
Q
на партии реальных образцовIb
; - Исправить
D
иQ
веса; - Поезд
G
, передавая свои сгенерированные выходыIb'
вD
иQ
для распространения через них.
Примечание. Это действительно грубое описание архитектуры. Я бы рекомендовал пройти литературу ([1, 5, 6, 7] как хороший старт), чтобы получить более подробную информацию и, возможно, более сложное решение.
Рекомендации
- Изола, Филлип и др. "Преобразование изображения в изображение с условными состязательными сетями". Препринт arXiv (2017). http://openaccess.thecvf.com/content_cvpr_2017/papers/Isola_Image-To-Image_Translation_With_CVPR_2017_paper.pdf
- Zhu, Jun-Yan, et al. "Непарный перевод изображения в изображение с использованием согласованных по последовательному сценарию состязательных сетей". arXiv preprint arXiv: 1703.10593 (2017). http://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf
- Мирза, Мехди и Симон Осиндеро. "Условные генеративные состязательные сети". arXiv preprint arXiv: 1411.1784 (2014). https://arxiv.org/pdf/1411.1784
- Goodfellow, Ian, et al. "Генеративные состязательные сети". Достижения в системах обработки нейронной информации. 2014. http://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
- Chen, Xi, et al. "Infogan: Интерпретируемое представление обучения посредством информации, максимизирующей генеративные состязательные сети". Достижения в нейронных системах обработки информации. 2016. http://papers.nips.cc/paper/6399-infogan-interpretable-representation-learning-by-information-maximizing-generative-adversarial-nets.pdf
- Ли, Минхёк и Джунхи Сёк. "Управляемая генерирующая сеть Adversarial". arXiv preprint arXiv: 1708.00598 (2017). https://arxiv.org/pdf/1708.00598.pdf
- Одена, Август, Кристофер Ола и Джонатон Шленс. "Синтез условного изображения с помощью вспомогательных классификаторов". arXiv preprint arXiv: 1610.09585 (2016). http://proceedings.mlr.press/v70/odena17a/odena17a.pdf