CGO 13-2: Generative Adversarial Network (GAN)
Download the Jupyter Notebook file from here.
This is a homework exercise on implementing a Generative Adversarial Network (GAN) in PyTorch.
Exercise
- Complete the loss functions for the discriminator and the generator (Hint: use the binary cross-entropy loss with logits, F.binary_cross_entropy_with_logits); one possible completion is sketched below.
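One possible way to fill in the three TODO lines in the training loop (a sketch, not the only valid formulation): label real images as 1 and generated images as 0, and train the generator so that the discriminator classifies its samples as real. Since the discriminator returns raw logits, each loss is a single call to F.binary_cross_entropy_with_logits; pred refers to the discriminator output computed just before the corresponding TODO line.

# Discriminator, real batch: logits should be classified as real (target 1)
lossD_real = F.binary_cross_entropy_with_logits(pred, torch.ones_like(pred))

# Discriminator, generated batch: logits should be classified as fake (target 0)
lossD_fake = F.binary_cross_entropy_with_logits(pred, torch.zeros_like(pred))

# Generator: fool the discriminator into labeling generated images as real (target 1)
lossG = F.binary_cross_entropy_with_logits(pred, torch.ones_like(pred))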
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
# Discriminator: an MLP that maps a flattened 28x28 image to a single real/fake logit.
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),   # raw logit; no sigmoid here
        )

    def forward(self, x):
        return self.layers(x.view(-1, 28*28))
# Generator: an MLP that maps a 100-dimensional noise vector to a 28x28 image in [-1, 1].
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(100, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 28*28),
        )

    def forward(self, x):
        # tanh keeps the output in [-1, 1], matching the Normalize((.5,), (.5,)) preprocessing
        return torch.tanh(self.layers(x)).view(-1, 1, 28, 28)
Dtransform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((.5,), (.5,))   # scale pixel values to [-1, 1]
])
Dtrain = torchvision.datasets.MNIST('../data', train=True, transform=Dtransform, download=True )
Dtest = torchvision.datasets.MNIST('../data', train=False, transform=Dtransform )
batchsize = 64
train_loader = torch.utils.data.DataLoader( Dtrain, batch_size=batchsize, shuffle=True, drop_last=True )
test_loader = torch.utils.data.DataLoader( Dtest, batch_size=batchsize )
criterion = nn.BCEWithLogitsLoss()  # the discriminator outputs raw logits, so use the with-logits variant (cf. the functional hint above)
device = torch.device("cuda" if torch.cuda.device_count()>0 else "cpu")
D = Discriminator().to(device)
G = Generator().to(device)
optimizerG = torch.optim.Adam( G.parameters(), lr=0.0002, betas=(0.5, 0.999) )
optimizerD = torch.optim.Adam( D.parameters(), lr=0.0002, betas=(0.5, 0.999) )
lossD_real_log = []
lossD_fake_log = []
lossG_log = []
G.train()
D.train()
for epoch in range(5):
    t = tqdm(train_loader, desc=("Epoch %d" % (epoch+1)))
    for x, y in t:
        # Training Discriminator
        x, y = x.to(device), y.to(device)
        pred = D(x)                                  # logits on real images
        lossD_real = None  # TODO implement loss (see the sketch after the exercise hint)
        z = torch.rand(batchsize, 100).to(device)
        pred = D(G(z).detach())                      # logits on generated images; detach so G is not updated here
        lossD_fake = None  # TODO implement loss
        D.zero_grad()
        lossD = 0.5 * (lossD_real + lossD_fake)
        lossD.backward()
        optimizerD.step()
        # Training Generator
        z = torch.rand(batchsize, 100).to(device)
        pred = D(G(z))
        lossG = None  # TODO implement loss
        G.zero_grad()
        lossG.backward()
        optimizerG.step()
        lossD_real_log.append(lossD_real.item())
        lossD_fake_log.append(lossD_fake.item())
        lossG_log.append(lossG.item())
        t.set_postfix(lossD_real=lossD_real.item(), lossD_fake=lossD_fake.item(), lossG=lossG.item())
%matplotlib inline
import matplotlib.pyplot as plt
plt.plot( lossD_real_log, label='LossD Real' )
plt.plot( lossD_fake_log, label='LossD Fake' )
plt.plot( lossG_log, label='LossG' )
plt.legend( frameon=False )
import matplotlib.pyplot as plt
import torchvision
G.eval()
with torch.no_grad():
    I = G(torch.rand(10, 100).to(device))
I = (I + 1) / 2  # map the tanh output from [-1, 1] back to [0, 1] for display
I = torchvision.utils.make_grid(I, nrow=5)
plt.imshow(I.permute(1, 2, 0).to('cpu'))