Estoy tratando de crear un modelo GAN. Este es mi discriminator.py:
import torch.nn as nn


class D(nn.Module):
    """DCGAN-style convolutional discriminator.

    Four stride-2 convolutions each halve the spatial resolution while the
    channel count doubles (``feature_maps`` -> ``feature_maps * 8``).  A final
    4x4 convolution without padding followed by a sigmoid produces the score.
    With 64x64 inputs the final feature map is 1x1, so ``forward`` returns one
    value in (0, 1) per image.

    Args:
        in_channels: Number of channels of the input images.  Defaults to 4,
            the value that used to be hard-coded (RGBA images).
    """

    # Hyper-parameters shared by the convolutional layers.
    feature_maps = 64
    kernel_size = 4
    stride = 2
    padding = 1
    bias = False
    inplace = True

    def __init__(self, in_channels=4):
        super(D, self).__init__()
        fm = self.feature_maps
        # Positional arguments shared by every down-sampling convolution.
        down = (self.kernel_size, self.stride, self.padding)
        # The layer order is kept exactly as before so that the keys of
        # ``state_dict()`` ("main.0.weight", ...) stay unchanged.
        self.main = nn.Sequential(
            nn.Conv2d(in_channels, fm, *down, bias=self.bias),
            nn.LeakyReLU(0.2, inplace=self.inplace),
            nn.Conv2d(fm, fm * 2, *down, bias=self.bias),
            nn.BatchNorm2d(fm * 2),
            nn.LeakyReLU(0.2, inplace=self.inplace),
            nn.Conv2d(fm * 2, fm * 4, *down, bias=self.bias),
            nn.BatchNorm2d(fm * 4),
            nn.LeakyReLU(0.2, inplace=self.inplace),
            nn.Conv2d(fm * 4, fm * 8, *down, bias=self.bias),
            nn.BatchNorm2d(fm * 8),
            nn.LeakyReLU(0.2, inplace=self.inplace),
            # Stride 1 and no padding: collapses a 4x4 map to a single value.
            nn.Conv2d(fm * 8, 1, self.kernel_size, 1, 0, bias=self.bias),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Run ``input`` through the network and return the scores flattened to 1-D."""
        output = self.main(input)
        return output.view(-1)
Este es mi generator.py:
import torch.nn as nn


class G(nn.Module):
    """DCGAN-style transposed-convolution generator.

    The first layer expands a ``(N, input_vector, 1, 1)`` noise tensor to a
    4x4 map with ``feature_maps`` channels.  Each of the following four
    stride-2 transposed convolutions doubles the spatial resolution, giving
    64x64 outputs.  The final ``Tanh`` keeps every output value in [-1, 1].

    Args:
        input_vector: Number of channels of the input noise tensor.
        out_channels: Number of channels of the generated images.  Defaults
            to 4, the value that used to be hard-coded (RGBA images).
    """

    # Hyper-parameters shared by the transposed-convolution layers.
    feature_maps = 512
    kernel_size = 4
    stride = 2
    padding = 1
    bias = False

    def __init__(self, input_vector, out_channels=4):
        super(G, self).__init__()
        fm = self.feature_maps
        # Floor division already yields ints, so no int() wrapper is needed;
        # fm // 4 and fm // 8 equal the original nested ``// 2`` chains.
        half, quarter, eighth = fm // 2, fm // 4, fm // 8
        # Positional arguments shared by every up-sampling layer.
        up = (self.kernel_size, self.stride, self.padding)
        # The layer order is kept exactly as before so that the keys of
        # ``state_dict()`` stay unchanged.
        self.main = nn.Sequential(
            # Stride 1 and no padding: 1x1 noise -> 4x4 feature map.
            nn.ConvTranspose2d(input_vector, fm, self.kernel_size, 1, 0, bias=self.bias),
            nn.BatchNorm2d(fm),
            nn.ReLU(True),
            nn.ConvTranspose2d(fm, half, *up, bias=self.bias),
            nn.BatchNorm2d(half),
            nn.ReLU(True),
            nn.ConvTranspose2d(half, quarter, *up, bias=self.bias),
            nn.BatchNorm2d(quarter),
            nn.ReLU(True),
            nn.ConvTranspose2d(quarter, eighth, *up, bias=self.bias),
            nn.BatchNorm2d(eighth),
            nn.ReLU(True),
            nn.ConvTranspose2d(eighth, out_channels, *up, bias=self.bias),
            nn.Tanh(),
        )

    def forward(self, input):
        """Map a batch of noise tensors to a batch of generated images."""
        output = self.main(input)
        return output
Este es mi gans.py
# Importing the libraries from __future__ import print_function import torch.nn as nn import torch.optim as optim import torch.utils.data import torchvision.datasets as dset import torchvision.transforms as transforms import torchvision.utils as vutils from torch.autograd import Variable from generator import G from discriminator import D import os from PIL import Image batchSize = 64 # We set the size of the batch. imageSize = 64 # We set the size of the generated images (64x64). input_vector = 100 nb_epochs = 500 # Creating the transformations transform = transforms.Compose([transforms.Resize((imageSize, imageSize)), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5, 0.5), (0.5, 0.5, 0.5, 0.5)), ]) # We create a list of transformations (scaling, tensor conversion, normalization) to apply to the input images. def pil_loader_rgba(path: str) -> Image.Image: with open(path, 'rb') as f: img = Image.open(f) return img.convert('RGBA') # Loading the dataset dataset = dset.ImageFolder(root='./data', transform=transform, loader=pil_loader_rgba) dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchSize, shuffle=True, num_workers=2) # We use dataLoader to get the images of the training set batch by batch. # Defining the weights_init function that takes as input a neural network m and that will initialize all its weights. 
def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: m.weight.data.normal_(0.0, 0.02) elif classname.find('BatchNorm') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) def is_cuda_available(): return torch.cuda.is_available() def is_gpu_available(): if is_cuda_available(): if int(torch.cuda.device_count()) > 0: return True return False return False # Create results directory def create_dir(name): if not os.path.exists(name): os.makedirs(name) # Creating the generator netG = G(input_vector) netG.apply(weights_init) # Creating the discriminator netD = D() netD.apply(weights_init) if is_gpu_available(): netG.cuda() netD.cuda() # Training the DCGANs criterion = nn.BCELoss() optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999)) optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999)) generator_model = 'generator_model' discriminator_model = 'discriminator_model' def save_model(epoch, model, optimizer, error, filepath, noise=None): if os.path.exists(filepath): os.remove(filepath) torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': error, 'noise': noise }, filepath) def load_checkpoint(filepath): if os.path.exists(filepath): return torch.load(filepath) return None def main(): print("Device name : " + torch.cuda.get_device_name(0)) for epoch in range(nb_epochs): for i, data in enumerate(dataloader, 0): checkpointG = load_checkpoint(generator_model) checkpointD = load_checkpoint(discriminator_model) if checkpointG: netG.load_state_dict(checkpointG['model_state_dict']) optimizerG.load_state_dict(checkpointG['optimizer_state_dict']) if checkpointD: netD.load_state_dict(checkpointD['model_state_dict']) optimizerD.load_state_dict(checkpointD['optimizer_state_dict']) # 1st Step: Updating the weights of the neural network of the discriminator netD.zero_grad() # Training the discriminator with a real image of the dataset real, 
_ = data if is_gpu_available(): input = Variable(real.cuda()).cuda() target = Variable(torch.ones(input.size()[0]).cuda()).cuda() else: input = Variable(real) target = Variable(torch.ones(input.size()[0])) output = netD(input) errD_real = criterion(output, target) # Training the discriminator with a fake image generated by the generator if is_gpu_available(): noise = Variable(torch.randn(input.size()[0], input_vector, 1, 1)).cuda() target = Variable(torch.zeros(input.size()[0])).cuda() else: noise = Variable(torch.randn(input.size()[0], input_vector, 1, 1)) target = Variable(torch.zeros(input.size()[0])) fake = netG(noise) output = netD(fake.detach()) errD_fake = criterion(output, target) # Backpropagating the total error errD = errD_real + errD_fake errD.backward() optimizerD.step() # 2nd Step: Updating the weights of the neural network of the generator netG.zero_grad() if is_gpu_available(): target = Variable(torch.ones(input.size()[0])).cuda() else: target = Variable(torch.ones(input.size()[0])) output = netD(fake) errG = criterion(output, target) errG.backward() optimizerG.step() # 3rd Step: Printing the losses and saving the real images and the generated images of the minibatch every 100 steps print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % ( epoch, nb_epochs, i, len(dataloader), errD.data, errG.data)) save_model(epoch, netG, optimizerG, errG, generator_model, noise) save_model(epoch, netD, optimizerD, errD, discriminator_model, noise) if i % 100 == 0: create_dir('results') vutils.save_image(real, '%s/real_samples.png' % "./results", normalize=True) fake = netG(noise) vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % ("./results", epoch), normalize=True) if __name__ == "__main__": main()
Así que, DESPUÉS de unas horas, decidí mirar mi carpeta de resultados y vi cosas extrañas a partir de la época 39: el generador comenzó a generar imágenes mucho peores, cuando hasta la época 39 había ido MEJORANDO. Por favor, mire la siguiente captura de pantalla.
¿Por qué el generador empeoró de repente? Estoy tratando de ejecutar 500 épocas. Pensé que, a más épocas, mejores resultados.
Así que eché un vistazo a los registros y veo a continuación
[40/500][0/157] Loss_D: 0.0141 Loss_G: 5.7559 [40/500][1/157] Loss_D: 0.0438 Loss_G: 5.5805 [40/500][2/157] Loss_D: 0.0161 Loss_G: 6.4947 [40/500][3/157] Loss_D: 0.0138 Loss_G: 7.1711 [40/500][4/157] Loss_D: 0.0547 Loss_G: 4.6262 [40/500][5/157] Loss_D: 0.0295 Loss_G: 4.7831 [40/500][6/157] Loss_D: 0.0103 Loss_G: 6.3700 [40/500][7/157] Loss_D: 0.0276 Loss_G: 5.9162 [40/500][8/157] Loss_D: 0.0205 Loss_G: 6.3571 [40/500][9/157] Loss_D: 0.0139 Loss_G: 6.4961 [40/500][10/157] Loss_D: 0.0117 Loss_G: 6.4371 [40/500][11/157] Loss_D: 0.0057 Loss_G: 6.6858 [40/500][12/157] Loss_D: 0.0203 Loss_G: 5.4308 [40/500][13/157] Loss_D: 0.0078 Loss_G: 6.5749 [40/500][14/157] Loss_D: 0.0115 Loss_G: 6.3202 [40/500][15/157] Loss_D: 0.0187 Loss_G: 6.2258 [40/500][16/157] Loss_D: 0.0052 Loss_G: 6.5253 [40/500][17/157] Loss_D: 0.0158 Loss_G: 5.5672 [40/500][18/157] Loss_D: 0.0156 Loss_G: 5.5416 [40/500][19/157] Loss_D: 0.0306 Loss_G: 5.4550 [40/500][20/157] Loss_D: 0.0077 Loss_G: 6.1985 [40/500][21/157] Loss_D: 0.0158 Loss_G: 5.3092 [40/500][22/157] Loss_D: 0.0167 Loss_G: 5.8395 [40/500][23/157] Loss_D: 0.0119 Loss_G: 6.0849 [40/500][24/157] Loss_D: 0.0104 Loss_G: 6.5493 [40/500][25/157] Loss_D: 0.0182 Loss_G: 5.6758 [40/500][26/157] Loss_D: 0.0145 Loss_G: 5.8336 [40/500][27/157] Loss_D: 0.0050 Loss_G: 6.8472 [40/500][28/157] Loss_D: 0.0080 Loss_G: 6.4894 [40/500][29/157] Loss_D: 0.0186 Loss_G: 5.5563 [40/500][30/157] Loss_D: 0.0143 Loss_G: 6.4144 [40/500][31/157] Loss_D: 0.0377 Loss_G: 5.4557 [40/500][32/157] Loss_D: 0.0540 Loss_G: 4.6034 [40/500][33/157] Loss_D: 0.0200 Loss_G: 5.6417 [40/500][34/157] Loss_D: 0.0189 Loss_G: 5.7760 [40/500][35/157] Loss_D: 0.0197 Loss_G: 6.1732 [40/500][36/157] Loss_D: 0.0093 Loss_G: 6.4046 [40/500][37/157] Loss_D: 0.0281 Loss_G: 5.5217 [40/500][38/157] Loss_D: 0.0410 Loss_G: 5.9157 [40/500][39/157] Loss_D: 0.0667 Loss_G: 5.2522 [40/500][40/157] Loss_D: 0.0530 Loss_G: 5.6412 [40/500][41/157] Loss_D: 0.0315 Loss_G: 5.9325 [40/500][42/157] Loss_D: 0.0097 
Loss_G: 6.7819 [40/500][43/157] Loss_D: 0.0157 Loss_G: 5.8630 [40/500][44/157] Loss_D: 0.0382 Loss_G: 5.1942 [40/500][45/157] Loss_D: 0.0331 Loss_G: 5.1490 [40/500][46/157] Loss_D: 0.0362 Loss_G: 5.7026 [40/500][47/157] Loss_D: 0.0237 Loss_G: 5.7493 [40/500][48/157] Loss_D: 0.0227 Loss_G: 5.7636 [40/500][49/157] Loss_D: 0.0230 Loss_G: 5.6500 [40/500][50/157] Loss_D: 0.0329 Loss_G: 5.4542 [40/500][51/157] Loss_D: 0.0306 Loss_G: 5.6473 [40/500][52/157] Loss_D: 0.0254 Loss_G: 5.8464 [40/500][53/157] Loss_D: 0.0402 Loss_G: 5.8609 [40/500][54/157] Loss_D: 0.0242 Loss_G: 5.9952 [40/500][55/157] Loss_D: 0.0400 Loss_G: 5.8378 [40/500][56/157] Loss_D: 0.0302 Loss_G: 5.8990 [40/500][57/157] Loss_D: 0.0239 Loss_G: 5.8134 [40/500][58/157] Loss_D: 0.0348 Loss_G: 5.8109 [40/500][59/157] Loss_D: 0.0361 Loss_G: 5.9011 [40/500][60/157] Loss_D: 0.0418 Loss_G: 5.8825 [40/500][61/157] Loss_D: 0.0501 Loss_G: 6.2302 [40/500][62/157] Loss_D: 0.0184 Loss_G: 6.2755 [40/500][63/157] Loss_D: 0.0273 Loss_G: 5.9655 [40/500][64/157] Loss_D: 0.0250 Loss_G: 5.7513 [40/500][65/157] Loss_D: 0.0298 Loss_G: 6.0434 [40/500][66/157] Loss_D: 0.0299 Loss_G: 6.4280 [40/500][67/157] Loss_D: 0.0205 Loss_G: 6.3743 [40/500][68/157] Loss_D: 0.0173 Loss_G: 6.2749 [40/500][69/157] Loss_D: 0.0199 Loss_G: 6.0541 [40/500][70/157] Loss_D: 0.0309 Loss_G: 6.5044 [40/500][71/157] Loss_D: 0.0177 Loss_G: 6.6093 [40/500][72/157] Loss_D: 0.0363 Loss_G: 7.2993 [40/500][73/157] Loss_D: 0.0093 Loss_G: 7.6995 [40/500][74/157] Loss_D: 0.0087 Loss_G: 7.3493 [40/500][75/157] Loss_D: 0.0540 Loss_G: 8.2688 [40/500][76/157] Loss_D: 0.0172 Loss_G: 8.3312 [40/500][77/157] Loss_D: 0.0086 Loss_G: 7.6863 [40/500][78/157] Loss_D: 0.0232 Loss_G: 7.4930 [40/500][79/157] Loss_D: 0.0175 Loss_G: 7.8834 [40/500][80/157] Loss_D: 0.0109 Loss_G: 9.5329 [40/500][81/157] Loss_D: 0.0093 Loss_G: 7.3253 [40/500][82/157] Loss_D: 0.0674 Loss_G: 10.6709 [40/500][83/157] Loss_D: 0.0010 Loss_G: 10.8321 [40/500][84/157] Loss_D: 0.0083 Loss_G: 8.5728 
[40/500][85/157] Loss_D: 0.0124 Loss_G: 6.9085 [40/500][86/157] Loss_D: 0.0181 Loss_G: 7.0867 [40/500][87/157] Loss_D: 0.0130 Loss_G: 7.3527 [40/500][88/157] Loss_D: 0.0189 Loss_G: 7.2494 [40/500][89/157] Loss_D: 0.0302 Loss_G: 8.7555 [40/500][90/157] Loss_D: 0.0147 Loss_G: 7.7668 [40/500][91/157] Loss_D: 0.0325 Loss_G: 7.7779 [40/500][92/157] Loss_D: 0.0257 Loss_G: 8.3955 [40/500][93/157] Loss_D: 0.0113 Loss_G: 8.3687 [40/500][94/157] Loss_D: 0.0124 Loss_G: 7.6081 [40/500][95/157] Loss_D: 0.0088 Loss_G: 7.6012 [40/500][96/157] Loss_D: 0.0241 Loss_G: 7.6573 [40/500][97/157] Loss_D: 0.0522 Loss_G: 10.8114 [40/500][98/157] Loss_D: 0.0071 Loss_G: 11.0529 [40/500][99/157] Loss_D: 0.0043 Loss_G: 8.0707 [40/500][100/157] Loss_D: 0.0141 Loss_G: 7.2864 [40/500][101/157] Loss_D: 0.0234 Loss_G: 7.3585 [40/500][102/157] Loss_D: 0.0148 Loss_G: 7.4577 [40/500][103/157] Loss_D: 0.0190 Loss_G: 8.1904 [40/500][104/157] Loss_D: 0.0201 Loss_G: 8.1518 [40/500][105/157] Loss_D: 0.0220 Loss_G: 9.1069 [40/500][106/157] Loss_D: 0.0108 Loss_G: 9.0069 [40/500][107/157] Loss_D: 0.0044 Loss_G: 8.0970 [40/500][108/157] Loss_D: 0.0076 Loss_G: 7.2699 [40/500][109/157] Loss_D: 0.0052 Loss_G: 7.4036 [40/500][110/157] Loss_D: 0.0167 Loss_G: 7.2742 [40/500][111/157] Loss_D: 0.0032 Loss_G: 7.9825 [40/500][112/157] Loss_D: 0.3462 Loss_G: 32.6314 [40/500][113/157] Loss_D: 0.1704 Loss_G: 40.6010 [40/500][114/157] Loss_D: 0.0065 Loss_G: 44.4607 [40/500][115/157] Loss_D: 0.0142 Loss_G: 43.9761 [40/500][116/157] Loss_D: 0.0160 Loss_G: 45.0376 [40/500][117/157] Loss_D: 0.0042 Loss_G: 45.9534 [40/500][118/157] Loss_D: 0.0061 Loss_G: 45.2998 [40/500][119/157] Loss_D: 0.0023 Loss_G: 45.4654 [40/500][120/157] Loss_D: 0.0033 Loss_G: 44.6643 [40/500][121/157] Loss_D: 0.0042 Loss_G: 44.6020 [40/500][122/157] Loss_D: 0.0002 Loss_G: 44.4807 [40/500][123/157] Loss_D: 0.0004 Loss_G: 44.0402 [40/500][124/157] Loss_D: 0.0055 Loss_G: 43.9188 [40/500][125/157] Loss_D: 0.0021 Loss_G: 43.1988 [40/500][126/157] Loss_D: 
0.0008 Loss_G: 41.6770 [40/500][127/157] Loss_D: 0.0001 Loss_G: 40.8719 [40/500][128/157] Loss_D: 0.0009 Loss_G: 40.3803 [40/500][129/157] Loss_D: 0.0023 Loss_G: 39.0143 [40/500][130/157] Loss_D: 0.0254 Loss_G: 39.0317 [40/500][131/157] Loss_D: 0.0008 Loss_G: 37.9451 [40/500][132/157] Loss_D: 0.0253 Loss_G: 37.1046 [40/500][133/157] Loss_D: 0.0046 Loss_G: 36.2807 [40/500][134/157] Loss_D: 0.0025 Loss_G: 35.5878 [40/500][135/157] Loss_D: 0.0011 Loss_G: 33.6500 [40/500][136/157] Loss_D: 0.0061 Loss_G: 33.5011 [40/500][137/157] Loss_D: 0.0015 Loss_G: 30.0363 [40/500][138/157] Loss_D: 0.0019 Loss_G: 31.0197 [40/500][139/157] Loss_D: 0.0027 Loss_G: 28.4693 [40/500][140/157] Loss_D: 0.0189 Loss_G: 27.3072 [40/500][141/157] Loss_D: 0.0051 Loss_G: 26.6637 [40/500][142/157] Loss_D: 0.0077 Loss_G: 24.8390 [40/500][143/157] Loss_D: 0.0123 Loss_G: 23.8334 [40/500][144/157] Loss_D: 0.0014 Loss_G: 23.3755 [40/500][145/157] Loss_D: 0.0036 Loss_G: 19.6341 [40/500][146/157] Loss_D: 0.0025 Loss_G: 18.1076 [40/500][147/157] Loss_D: 0.0029 Loss_G: 16.9415 [40/500][148/157] Loss_D: 0.0028 Loss_G: 16.4647 [40/500][149/157] Loss_D: 0.0048 Loss_G: 14.6184 [40/500][150/157] Loss_D: 0.0074 Loss_G: 13.2544 [40/500][151/157] Loss_D: 0.0053 Loss_G: 13.0052 [40/500][152/157] Loss_D: 0.0070 Loss_G: 11.8815 [40/500][153/157] Loss_D: 0.0078 Loss_G: 12.1657 [40/500][154/157] Loss_D: 0.0094 Loss_G: 10.4259 [40/500][155/157] Loss_D: 0.0073 Loss_G: 9.9345 [40/500][156/157] Loss_D: 0.0082 Loss_G: 9.7609 [41/500][0/157] Loss_D: 0.0079 Loss_G: 9.2920 [41/500][1/157] Loss_D: 0.0134 Loss_G: 8.5241 [41/500][2/157] Loss_D: 0.0156 Loss_G: 8.6983 [41/500][3/157] Loss_D: 0.0250 Loss_G: 8.1148 [41/500][4/157] Loss_D: 0.0160 Loss_G: 8.3324 [41/500][5/157] Loss_D: 0.0187 Loss_G: 7.6281 [41/500][6/157] Loss_D: 0.0191 Loss_G: 7.4707 [41/500][7/157] Loss_D: 0.0092 Loss_G: 8.3976 [41/500][8/157] Loss_D: 0.0118 Loss_G: 7.9800 [41/500][9/157] Loss_D: 0.0126 Loss_G: 7.3999 [41/500][10/157] Loss_D: 0.0165 Loss_G: 7.0854 
[41/500][11/157] Loss_D: 0.0095 Loss_G: 7.6392 [41/500][12/157] Loss_D: 0.0079 Loss_G: 7.3862 [41/500][13/157] Loss_D: 0.0181 Loss_G: 7.3812 [41/500][14/157] Loss_D: 0.0168 Loss_G: 6.9518 [41/500][15/157] Loss_D: 0.0094 Loss_G: 7.8525 [41/500][16/157] Loss_D: 0.0165 Loss_G: 7.3024 [41/500][17/157] Loss_D: 0.0029 Loss_G: 8.4487 [41/500][18/157] Loss_D: 0.0169 Loss_G: 7.0449 [41/500][19/157] Loss_D: 0.0167 Loss_G: 7.1307 [41/500][20/157] Loss_D: 0.0255 Loss_G: 6.7970 [41/500][21/157] Loss_D: 0.0154 Loss_G: 6.9745 [41/500][22/157] Loss_D: 0.0110 Loss_G: 6.9925
Como puede ver, se produjo un GRAN cambio en la pérdida del generador (Loss_G).
¿Alguna idea de por qué sucedió eso?
¿Alguna idea de cómo superar tal problema?
Un aumento repentino de la pérdida podría deberse a la explosión o desaparición del gradiente. Como no veo una configuración de recorte de gradiente en el código anterior, supongo que es un caso de gradiente explosivo.
Puede recortar sus gradientes después del backprop de esta forma:
# Backpropagate the generator loss, then rescale the gradients so that their
# combined L2 norm is at most 1.0 before the optimizer applies the update.
errG.backward()
generator_params = netG.parameters()
torch.nn.utils.clip_grad_norm_(generator_params, max_norm=1.0)
optimizerG.step()
o
# Backpropagate the generator loss, then clamp every gradient element to the
# range [-100, 100] before the optimizer applies the update.
errG.backward()
generator_params = netG.parameters()
torch.nn.utils.clip_grad_value_(generator_params, clip_value=100.0)
optimizerG.step()
nn.utils.clip_grad_norm_ escala sus gradientes de modo que la norma L2 de todos los gradientes juntos quede por debajo de max_norm. Este parámetro se puede configurar según lo que funcione mejor. Por lo general, se establece en una constante o en la media de sus gradientes concatenados.
nn.utils.clip_grad_value_ limita (clamp_) cada uno de los gradientes al intervalo +/- clip_value. Nuevamente, clip_value puede ser cualquier número de punto flotante grande, según lo que funcione mejor.
Un punto importante a tener en cuenta es que estamos tratando de evitar 0 o inf en los tensores, para que después del backprop no aparezcan nan o inf en los pesos. Por lo tanto, max_norm y clip_value deben seleccionarse adecuadamente.
Otra posible solución es usar un método de inicialización diferente en lugar del normal_ que está utilizando. Puede probar la inicialización de Xavier (Glorot) o la inicialización de Kaiming (He) de esta forma:
# Xavier (Glorot) initialization of the generator weights.
# ``gain`` was an undefined name in the original snippet (NameError); the
# recommended gain for ReLU is used because the generator's hidden layers
# are followed by ReLU activations.
relu_gain = torch.nn.init.calculate_gain('relu')
for param in netG.parameters():
    # Only tensors with more than one dimension are re-initialized; 1-D
    # parameters (e.g. BatchNorm weights and biases) are skipped.
    if param.dim() > 1:
        torch.nn.init.xavier_normal_(param, gain=relu_gain)
o
# Kaiming (He) initialization of the discriminator weights.
# ``a`` is the negative slope of the LeakyReLU that follows the layer; the
# discriminator uses LeakyReLU(0.2), so 0.2 is passed instead of relying on
# the default of 0 (which yields the gain of a plain ReLU).
for param in netD.parameters():
    # Only tensors with more than one dimension are re-initialized; 1-D
    # parameters (e.g. BatchNorm weights and biases) are skipped.
    if param.dim() > 1:
        torch.nn.init.kaiming_normal_(param, a=0.2, nonlinearity='leaky_relu')
Puede leer sobre el problema de inicialización aquí: https://pouannes.github.io/blog/initialization/
El entrenamiento GAN es intrínsecamente inestable debido al entrenamiento dinámico simultáneo de dos modelos en competencia. Intenté trazar los valores de pérdida de su pregunta y la pérdida de discriminador y generador se ve a continuación:
Mirando la pérdida y las imágenes generadas, podemos decir que el entrenamiento no logra converger. Esta falla se debe a que no se encuentra un equilibrio entre el discriminador y el generador. vemos que la pérdida para el discriminador es cercana a cero y la pérdida del generador aumenta y es inestable, lo que da como resultado imágenes basura que el discriminador puede identificar fácilmente como falsas.
El discriminador clasifica tanto los datos reales como los datos falsos del generador. La pérdida del discriminador lo penaliza por clasificar erróneamente una instancia real como falsa o una instancia falsa (creada por el generador) como real.
La pérdida del generador se basa en la clasificación del discriminador: se recompensa si logra engañar al discriminador y se penaliza en caso contrario. Una GAN es un juego no cooperativo de suma cero: la ganancia es del discriminador o del generador; si uno gana, el otro pierde. La convergencia ocurre en el equilibrio de Nash, que es cuando la acción de uno no afecta al otro. Lea más aquí: https://jonathan-hui.medium.com/gan-why-it-is-so-hard-to-train-generative-advisory-networks-819a86b3750b; además, https://jonathan-hui.medium.com/gan-what-is-wrong-with-the-gan-cost-function-6f594162ce01 proporciona una visión más profunda de los desafíos de las GAN.
La falla de convergencia también podría deberse al colapso de modo (mode collapse) y al desvanecimiento del gradiente. Además de la solución para gradientes explosivos sugerida por Nihal, intente implementar la detención temprana del modelo en función de métricas como Inception Score, Modified Inception Score, Fréchet Inception Distance o la distancia de Wasserstein (tomadas de este documento: https://arxiv.org/pdf/1802.03446.pdf). Estas medidas ayudan a identificar la convergencia del modelo, de modo que el entrenamiento se detenga automáticamente una vez que el modelo converge.
También se muestra que la Normalización Espectral, un tipo particular de normalización aplicada en los núcleos convolucionales, puede ayudar mucho a la estabilidad del entrenamiento. https://arxiv.org/pdf/1802.05957.pdf
Hacer que el entrenamiento del discriminador sea más difícil podría ayudar. Agregar ruido tanto a las imágenes reales como a las imágenes del generador ayuda a aumentar la complejidad del entrenamiento del discriminador.
Aumentar las iteraciones no siempre mejora el modelo. Más iteraciones de entrenamiento, más allá de cierto punto de estabilidad del entrenamiento, pueden o no dar como resultado imágenes de mayor calidad debido a la pérdida de alta varianza. Y dado que las GAN son relativamente nuevas, la dirección de investigación sobre los desafíos enfrentados aún está abierta y es discutible.