CGO 13-2: Generative Adversarial Network (GAN)

Download the Jupyter Notebook file from here.

CGO 13-2: Generative Adversarial Network (GAN)

This homework exercise implements a Generative Adversarial Network (GAN) in PyTorch.

Exercise

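  1. Complete the loss functions marked `# TODO` (Hint: the discriminator outputs raw logits, so use the binary cross-entropy loss `F.binary_cross_entropy_with_logits` — twice for the discriminator, once with "real" targets and once with "fake" targets, and once more for the generator)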
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
class Discriminator(nn.Module):
    """MLP discriminator that scores 28x28 images.

    Each image is flattened to 784 values and mapped to a single
    unnormalized score (a logit) of shape (N, 1). No sigmoid is applied,
    so pair it with a logits-based loss such as
    F.binary_cross_entropy_with_logits.
    """

    def __init__(self):
        super().__init__()
        widths = (28 * 28, 512, 256)
        stack = []
        # Hidden stages: 784 -> 512 -> 256, each followed by LeakyReLU(0.2).
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2))
        # Final projection to one logit per image.
        stack.append(nn.Linear(widths[-1], 1))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        # Collapse any (N, 1, 28, 28)-style input into rows of 784 pixels.
        flat = x.view(-1, 28 * 28)
        return self.layers(flat)

class Generator(nn.Module):
    """MLP generator mapping 100-dim latent vectors to 1x28x28 images.

    The output passes through tanh, so pixel values lie in [-1, 1]. That is
    the same range the MNIST images have after Normalize((.5,), (.5,)).
    """

    def __init__(self):
        super().__init__()
        stack = [nn.Linear(100, 128), nn.LeakyReLU(0.2)]
        # Widening stages 128 -> 256 -> 512 -> 1024, each Linear + BatchNorm + LeakyReLU.
        for fan_in, fan_out in ((128, 256), (256, 512), (512, 1024)):
            stack += [nn.Linear(fan_in, fan_out), nn.BatchNorm1d(fan_out), nn.LeakyReLU(0.2)]
        # Project to one value per pixel.
        stack.append(nn.Linear(1024, 28 * 28))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        pixels = torch.tanh(self.layers(x))
        return pixels.view(-1, 1, 28, 28)
# ToTensor yields pixels in [0, 1]; Normalize with mean 0.5 / std 0.5 shifts
# them to [-1, 1], matching the range of the generator's tanh output.
Dtransform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(.5,), std=(.5,)),
    ]
)

Dtrain = torchvision.datasets.MNIST(
    root='../data', train=True, transform=Dtransform, download=True
)
Dtest = torchvision.datasets.MNIST(
    root='../data', train=False, transform=Dtransform
)

batchsize = 64
# drop_last keeps every training batch exactly `batchsize` images long.
train_loader = torch.utils.data.DataLoader(
    dataset=Dtrain, batch_size=batchsize, shuffle=True, drop_last=True
)
test_loader = torch.utils.data.DataLoader(dataset=Dtest, batch_size=batchsize)

# D returns raw logits (its last layer is an nn.Linear with no sigmoid).
# nn.BCELoss expects probabilities in [0, 1] and would reject those values.
# The logits-aware variant is the matching criterion.
criterion = nn.BCEWithLogitsLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build D before G so parameter initialization consumes the RNG in the same order.
D = Discriminator().to(device)
G = Generator().to(device)

# Both networks use Adam with lr=2e-4 and a lowered first-moment decay (beta1=0.5).
optimizerG = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Per-iteration loss histories, plotted in the next cell.
lossD_real_log = []
lossD_fake_log = []
lossG_log = []
G.train()
D.train()
for epoch in range(5):
    t = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
    for x, _ in t:  # class labels are not used by an unconditional GAN
        x = x.to(device)
        # Size noise and targets from the actual batch rather than `batchsize`,
        # so the loop stays correct even if drop_last is turned off.
        n = x.size(0)
        real_target = torch.ones(n, 1, device=device)
        fake_target = torch.zeros(n, 1, device=device)

        # ---- Discriminator step: push D(real) -> 1 and D(fake) -> 0 ----
        pred = D(x)
        lossD_real = F.binary_cross_entropy_with_logits(pred, real_target)

        # Latent noise is drawn from U[0, 1), the same distribution the
        # sampling cell uses. detach() keeps this step from updating G.
        z = torch.rand(n, 100, device=device)
        pred = D(G(z).detach())
        lossD_fake = F.binary_cross_entropy_with_logits(pred, fake_target)

        D.zero_grad()
        lossD = 0.5 * (lossD_real + lossD_fake)
        lossD.backward()
        optimizerD.step()

        # ---- Generator step: non-saturating loss, i.e. make D call fakes real ----
        z = torch.rand(n, 100, device=device)
        pred = D(G(z))
        lossG = F.binary_cross_entropy_with_logits(pred, real_target)

        # This backward also writes grads into D's parameters. They are
        # cleared by D.zero_grad() before D's next update.
        G.zero_grad()
        lossG.backward()
        optimizerG.step()

        # Read each scalar once (each .item() forces a device sync).
        d_real, d_fake, g_loss = lossD_real.item(), lossD_fake.item(), lossG.item()
        lossD_real_log.append(d_real)
        lossD_fake_log.append(d_fake)
        lossG_log.append(g_loss)
        t.set_postfix(lossD_real=d_real, lossD_fake=d_fake, lossG=g_loss)
%matplotlib inline
import matplotlib.pyplot as plt
plt.plot( lossD_real_log, label='LossD Real' )
plt.plot( lossD_fake_log, label='LossD Fake' )
plt.plot( lossG_log, label='LossG' )
plt.legend( frameon=False )
import matplotlib.pyplot as plt
import torchvision

# Sample 10 images with BatchNorm in inference mode and no gradient tracking.
# Noise is U[0, 1) to match what G saw during training.
G.eval()
with torch.no_grad():
    I = G(torch.rand(10, 100).to(device))
# G ends in tanh, so samples lie in [-1, 1]. imshow clips float RGB data to
# [0, 1], which would discard the negative half of the range. Undo the
# Normalize((.5,), (.5,)) mapping before display.
I = (I * 0.5 + 0.5).clamp(0, 1)
# make_grid tiles the 10 single-channel images 5 per row as a 3-channel image.
I = torchvision.utils.make_grid(I, nrow=5)
plt.imshow(I.permute(1, 2, 0).cpu().numpy())