CGO 12-2: Computing Spatial Support

Download the Jupyter Notebook file from here.

CGO 12-2: Computing Spatial Support

Notes

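  1. If the last layer is Sigmoid or TanH, remove it to get a good estimate: Sigmoid maps 0 to 0.5, so every pixel would pass the `> 0` support test, and both functions saturate for the very large activations produced here.
  2. Using 32-bit floats leads to overflow (inf values), because the all-ones weights make activations grow rapidly with depth, so we use doubles (float64).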
import torch
import torch.nn as nn
import torch.nn.functional as F

PyTorch super-resolution example

https://github.com/pytorch/examples/blob/master/super_resolution/model.py

class Net(nn.Module):
    """Sub-pixel super-resolution network (PyTorch examples architecture).

    Three ReLU-activated convolutions (5x5, then two 3x3) keep the spatial
    size unchanged. A final 3x3 convolution produces ``upscale_factor ** 2``
    channels, which ``PixelShuffle`` rearranges into one channel upscaled by
    ``upscale_factor`` along each spatial axis.

    Args:
        upscale_factor: integer spatial upscaling factor.
    """

    def __init__(self, upscale_factor):
        super().__init__()

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(
            32, upscale_factor ** 2, kernel_size=3, stride=1, padding=1
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        # Feature extraction: each hidden convolution is followed by ReLU.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        # No activation before the shuffle.
        return self.pixel_shuffle(self.conv4(x))

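Set all convolution weights to 1 and all biases to 0, then propagate a single-pixel impulse through the network; the non-zero region of the output is the impulse's spatial support.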

model = Net(2)
model.eval()
model.double()
for name, p in model.named_parameters():
    if 'weight' in name:
        nn.init.constant_(p,1)
    elif 'bias' in name:
        nn.init.constant_(p,0)
s = 32
I = torch.zeros(1,1,s,s)
I[0,0,int(s/2-1),int(s/2-1)] = 1
with torch.no_grad():
    I = model(I.double())
I = F.interpolate(I, scale_factor=0.5, mode='bilinear', align_corners=True)
%matplotlib inline
import matplotlib.pyplot as plt
plt.imshow(I[0,0,:,:]/I.max())
plt.imshow(I[0,0,:,:]>0)

Deeper network (sketch simplification)

class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU unit.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
        kernel_size: square kernel size (default 3).
        padding: zero padding on each side (default 1).
    """

    def __init__(self, in_planes, out_planes, stride=1, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        return F.relu(y)
class Upsample(nn.Module):
    """Learned 2x upsampling: ConvTranspose2d -> BatchNorm2d -> ReLU.

    The transposed convolution (kernel 4, stride 2, padding 1) maps an
    H x W input to 2H x 2W.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
    """

    def __init__(self, in_planes, out_planes):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_planes, out_planes, kernel_size=4, stride=2, padding=1
        )
        self.bn = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        upsampled = self.conv(x)
        normalized = self.bn(upsampled)
        return F.relu(normalized)
class Net(nn.Module):
    """Fully convolutional encoder/decoder (sketch-simplification style).

    Three stride-2 ``Conv`` units reduce the resolution to 1/8. A stack of
    ``Conv`` units runs at that resolution, and three ``Upsample`` units
    restore the input resolution. A plain 3x3 ``nn.Conv2d`` (no BatchNorm,
    no ReLU) produces the single output channel.
    """

    def __init__(self):
        super().__init__()
        encoder = [
            Conv(1, 48, stride=2, kernel_size=7, padding=3),  # --> 1/2
            Conv(48, 128),
            Conv(128, 128),
            Conv(128, 128, stride=2),  # --> 1/4
            Conv(128, 256),
            Conv(256, 256),
            Conv(256, 256, stride=2),  # --> 1/8
        ]
        bottleneck = [
            Conv(256, 512),
            Conv(512, 1024),
            Conv(1024, 1024),
            Conv(1024, 1024),
            Conv(1024, 1024),
            Conv(1024, 512),
            Conv(512, 256),
        ]
        decoder = [
            Upsample(256, 256),  # --> 1/4
            Conv(256, 256),
            Conv(256, 128),
            Upsample(128, 128),  # --> 1/2
            Conv(128, 128),
            Conv(128, 128),
            Conv(128, 48),
            Upsample(48, 48),  # --> full resolution
            Conv(48, 24),
        ]
        head = nn.Conv2d(24, 1, kernel_size=3, stride=1, padding=1)
        # Single flat Sequential, so parameter names stay layers.<index>.*.
        self.layers = nn.Sequential(*encoder, *bottleneck, *decoder, head)

    def forward(self, x):
        return self.layers(x)
# All weights 1, all biases 0, and BatchNorm running statistics set to
# mean 0 / var 1. In eval mode, BatchNorm then only rescales by
# 1/sqrt(1 + eps), so every layer sums non-negative inputs, and the response
# is positive only where the impulse can reach.
model = Net()
model.eval()
model.double()  # float64: all-ones weights make activations grow with depth
for name, p in model.named_parameters():
    if 'weight' in name:
        nn.init.constant_(p, 1)
    elif 'bias' in name:
        nn.init.constant_(p, 0)
for name, b in model.named_buffers():
    if 'running_mean' in name:
        b.fill_(0)
    elif 'running_var' in name:
        b.fill_(1)

# Single unit impulse near the centre of an s x s input (integer index math).
s = 256
I = torch.zeros(1, 1, s, s, dtype=torch.float64)
I[0, 0, s // 2 - 1, s // 2 - 1] = 1
with torch.no_grad():
    I = model(I)
# Use separate axes so the binary mask does not cover the normalized response.
fig, (ax_resp, ax_supp) = plt.subplots(1, 2)
ax_resp.imshow(I[0, 0, :, :] / I.max())
ax_resp.set_title('normalized response')
ax_supp.imshow(I[0, 0, :, :] > 0)
ax_supp.set_title('spatial support (> 0)')

ResNet-based models

def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with zero padding of 1.

    With ``stride=1`` the spatial size of the input is preserved.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (no padding)."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
class BasicBlock(nn.Module):
    """ResNet basic residual block: two 3x3 conv + BatchNorm layers plus a skip.

    ``conv1`` carries ``stride``; ``conv2`` always uses stride 1. When
    ``downsample`` is given, it is applied to the block input to form the skip
    path. Otherwise the input is added unchanged, so its shape must be
    broadcast-compatible with the residual branch output (the add is in place).

    Args:
        inplanes: number of input channels.
        planes: number of output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module mapping the input onto the skip path.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Residual branch: conv -> BN -> ReLU -> conv -> BN.
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)
    
class Net(nn.Module):
    """Plain stack of 56 ``BasicBlock``s: one 1 -> 32 block, then 55 32 -> 32."""

    def __init__(self):
        super().__init__()
        # NOTE(review): the first block changes channels 1 -> 32 without a
        # downsample module, so its residual add relies on broadcasting the
        # single input channel across all 32 output channels.
        stack = [BasicBlock(1, 32)]
        stack.extend(BasicBlock(32, 32) for _ in range(55))
        self.blocks = nn.ModuleList(stack)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x
# All weights 1, all biases 0, and BatchNorm running statistics set to
# mean 0 / var 1, so the eval-mode network sums non-negative inputs and the
# response is positive only where the impulse can reach.
model = Net()
model.double()  # 56 all-ones blocks grow activations multiplicatively; float32 would overflow
model.eval()
for name, p in model.named_parameters():
    if 'weight' in name:
        nn.init.constant_(p, 1)
    elif 'bias' in name:
        nn.init.constant_(p, 0)
for name, b in model.named_buffers():
    if 'running_mean' in name:
        b.fill_(0)
    elif 'running_var' in name:
        b.fill_(1)

# Single unit impulse near the centre of an s x s input (integer index math).
s = 256
I = torch.zeros(1, 1, s, s, dtype=torch.float64)
I[0, 0, s // 2 - 1, s // 2 - 1] = 1
with torch.no_grad():
    I = model(I)
# Use separate axes so the binary mask does not cover the normalized response.
fig, (ax_resp, ax_supp) = plt.subplots(1, 2)
ax_resp.imshow(I[0, 0, :, :] / I.max())
ax_resp.set_title('normalized response')
ax_supp.imshow(I[0, 0, :, :] > 0)
ax_supp.set_title('spatial support (> 0)')