CGO 12-2: Computing Spatial Support
Jupyter Notebookのファイルをここからダウンロードしてください。
CGO 12-2: Computing Spatial Support
Notes
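- If the last layer is a Sigmoid or Tanh, remove it before measuring; otherwise it distorts the estimate (e.g. Sigmoid maps 0 to 0.5, so the `> 0` support mask no longer separates the support from the background).
- With every weight set to 1, activations grow multiplicatively with depth and quickly exceed the float32 range (about 3.4e38), which causes overflow errors, so we have to use doubles (float64).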
import torch
import torch.nn as nn
import torch.nn.functional as F
PyTorch super-resolution example, taken from:
https://github.com/pytorch/examples/blob/master/super_resolution/model.py
class Net(nn.Module):
    """Super-resolution CNN from the PyTorch examples repository.

    Four convolutions (ReLU after the first three) produce
    ``upscale_factor ** 2`` channels, which ``nn.PixelShuffle`` rearranges
    into a single channel upscaled by ``upscale_factor`` along H and W.

    Args:
        upscale_factor: integer spatial upscaling factor.
    """

    def __init__(self, upscale_factor):
        # Zero-argument super(): this notebook rebinds the global name ``Net``
        # for three different models, and ``super(Net, self)`` would look up
        # whichever class currently owns that name (raising TypeError).
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        """Map a 1-channel (N, 1, H, W) batch to (N, 1, r*H, r*W)."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        # No activation before the shuffle: the last layer stays linear.
        x = self.pixel_shuffle(self.conv4(x))
        return x
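Set all convolutional layer weights to 1 and all biases to 0, then feed a single-pixel impulse through the network: the nonzero region of the output is the network's spatial support.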
# Super-resolution net (upscale factor 2) in eval mode and float64; float32
# overflows once every weight is 1.
model = Net(2)
model.eval()
model.double()

# All weights 1, all biases 0: with a non-negative input every ReLU acts as the
# identity, so each output pixel is a positive sum over paths from the impulse.
for name, p in model.named_parameters():
    if 'weight' in name:
        nn.init.constant_(p, 1)
    elif 'bias' in name:
        nn.init.constant_(p, 0)

# Single-pixel impulse near the centre of an s x s image (integer arithmetic
# for the index instead of float division + truncation).
s = 32
center = s // 2 - 1
I = torch.zeros(1, 1, s, s, dtype=torch.float64)
I[0, 0, center, center] = 1

with torch.no_grad():
    I = model(I)

# The net upscales by 2; resample back to the s x s input grid.
I = F.interpolate(I, scale_factor=0.5, mode='bilinear', align_corners=True)
%matplotlib inline
import matplotlib.pyplot as plt
plt.imshow(I[0,0,:,:]/I.max())
plt.imshow(I[0,0,:,:]>0)
Deeper network (sketch simplification)
class Conv(nn.Module):
    """Convolution -> BatchNorm -> ReLU unit of the sketch-simplification net.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
        kernel_size: square kernel size (default 3).
        padding: zero padding on every side (default 1).
    """

    def __init__(self, in_planes, out_planes, stride=1, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        """Apply conv, batch-norm and ReLU in sequence."""
        features = self.conv(x)
        normalized = self.bn(features)
        return F.relu(normalized)
class Upsample(nn.Module):
    """Transposed convolution -> BatchNorm -> ReLU unit that doubles H and W.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
    """

    def __init__(self, in_planes, out_planes):
        super().__init__()
        # kernel 4, stride 2, padding 1 maps a size-n input to size 2n.
        self.conv = nn.ConvTranspose2d(
            in_planes, out_planes, kernel_size=4, stride=2, padding=1
        )
        self.bn = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        """Upsample by 2, batch-normalize, then apply ReLU."""
        upsampled = self.conv(x)
        upsampled = self.bn(upsampled)
        return F.relu(upsampled)
class Net(nn.Module):
    """Encoder-decoder from the sketch-simplification model.

    Three stride-2 ``Conv`` units reduce the input to 1/8 resolution and three
    ``Upsample`` units bring it back, ending in a 1-channel ``nn.Conv2d``.
    For an H x W input the output is 8*ceil(H/8) x 8*ceil(W/8), i.e. the
    input size when H and W are multiples of 8.
    """

    def __init__(self):
        # Zero-argument super(): this notebook rebinds the global name ``Net``
        # for three different models, and ``super(Net, self)`` would look up
        # whichever class currently owns that name (raising TypeError).
        super().__init__()
        self.layers = nn.Sequential(
            Conv(1, 48, 2, 7, 3),     # --> 1/2
            Conv(48, 128),
            Conv(128, 128),
            Conv(128, 128, 2),        # --> 1/4
            Conv(128, 256),
            Conv(256, 256),
            Conv(256, 256, 2),        # --> 1/8
            Conv(256, 512),
            Conv(512, 1024),
            Conv(1024, 1024),
            Conv(1024, 1024),
            Conv(1024, 1024),
            Conv(1024, 512),
            Conv(512, 256),
            Upsample(256, 256),       # --> 1/4
            Conv(256, 256),
            Conv(256, 128),
            Upsample(128, 128),       # --> 1/2
            Conv(128, 128),
            Conv(128, 128),
            Conv(128, 48),
            Upsample(48, 48),         # --> 1/1
            Conv(48, 24),
            nn.Conv2d(24, 1, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        """Run the full encoder-decoder on a 1-channel (N, 1, H, W) batch."""
        return self.layers(x)
# Sketch-simplification network in eval mode and float64 (float32 overflows).
model = Net()
model.eval()
model.double()

# Conv/BN weights 1, biases 0, and BN running statistics mean 0 / var 1, so
# every BatchNorm is (up to eps) the identity in eval mode.
for name, p in model.named_parameters():
    if 'weight' in name:
        nn.init.constant_(p, 1)
    elif 'bias' in name:
        nn.init.constant_(p, 0)
for name, b in model.named_buffers():
    if 'running_mean' in name:
        b.fill_(0)
    elif 'running_var' in name:
        b.fill_(1)

# Single-pixel impulse near the centre of an s x s image.
s = 256
center = s // 2 - 1
I = torch.zeros(1, 1, s, s, dtype=torch.float64)
I[0, 0, center, center] = 1
with torch.no_grad():
    I = model(I)

# Show the normalized response and its support on separate axes; two bare
# plt.imshow calls draw onto the same axes and the second hides the first.
response = I[0, 0, :, :]
fig, (ax_resp, ax_supp) = plt.subplots(1, 2, figsize=(10, 5))
ax_resp.imshow(response / I.max())
ax_resp.set_title('Normalized impulse response')
ax_supp.imshow(response > 0)
ax_supp.set_title('Spatial support (response > 0)')
plt.show()
ResNet-based models
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with one pixel of zero padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=(3, 3),
        stride=stride,
        padding=(1, 1),
        bias=False,
    )
    return layer
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d`` (no padding) with the given stride.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=(1, 1),
        stride=stride,
        bias=False,
    )
    return layer
class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 conv + BatchNorm layers and a shortcut.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first 3x3 convolution (default 1).
        downsample: optional module applied to the input to form the
            shortcut; when None the input itself is added.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        # In-place add: the shortcut is broadcast if it has fewer channels.
        branch += shortcut
        return self.relu(branch)
class Net(nn.Module):
    """Plain chain of ResNet ``BasicBlock``s used to probe spatial support.

    The first block maps the 1-channel input to ``width`` channels; the
    remaining ``num_blocks - 1`` blocks keep ``width`` channels. The first
    block gets no downsample projection, so its 1-channel identity shortcut
    is broadcast across all ``width`` channels by the in-place residual add.

    Args:
        num_blocks: total number of blocks (default 56, the original depth:
            one 1->32 block followed by 55 32->32 blocks).
        width: channel count of every block (default 32).

    Raises:
        ValueError: if ``num_blocks`` is smaller than 1.
    """

    def __init__(self, num_blocks=56, width=32):
        # Zero-argument super(): this notebook rebinds the global name ``Net``
        # for three different models, and ``super(Net, self)`` would look up
        # whichever class currently owns that name (raising TypeError).
        super().__init__()
        if num_blocks < 1:
            raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")
        self.blocks = nn.ModuleList([BasicBlock(1, width)])
        self.blocks.extend(BasicBlock(width, width) for _ in range(num_blocks - 1))

    def forward(self, x):
        """Apply every block in order."""
        for block in self.blocks:
            x = block(x)
        return x
# 56-block ResNet stack in float64 (float32 overflows) and eval mode.
model = Net()
model.double()
model.eval()

# Conv/BN weights 1, biases 0, and BN running statistics mean 0 / var 1, so
# every BatchNorm is (up to eps) the identity in eval mode.
for name, p in model.named_parameters():
    if 'weight' in name:
        nn.init.constant_(p, 1)
    elif 'bias' in name:
        nn.init.constant_(p, 0)
for name, b in model.named_buffers():
    if 'running_mean' in name:
        b.fill_(0)
    elif 'running_var' in name:
        b.fill_(1)

# Single-pixel impulse near the centre of an s x s image.
s = 256
center = s // 2 - 1
I = torch.zeros(1, 1, s, s, dtype=torch.float64)
I[0, 0, center, center] = 1
with torch.no_grad():
    I = model(I)

# Show the normalized response (first channel) and its support on separate
# axes; two bare plt.imshow calls draw onto the same axes and the second
# hides the first.
response = I[0, 0, :, :]
fig, (ax_resp, ax_supp) = plt.subplots(1, 2, figsize=(10, 5))
ax_resp.imshow(response / I.max())
ax_resp.set_title('Normalized impulse response')
ax_supp.imshow(response > 0)
ax_supp.set_title('Spatial support (response > 0)')
plt.show()