Creating ResNet18 from scratch
Recently I built a ResNet18 from scratch so I could modify it. Previously I showed what is inside ResNets, but only at a high level.
A few facts
There are several popular models:
- ResNet18
- ResNet34
- ResNet50
- ResNet101
For ResNet18 and ResNet34 we use basic blocks, and for ResNet50 and ResNet101 we use bottleneck blocks.
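This post only needs basic blocks, but just to show the contrast, here is a minimal sketch of a bottleneck identity block in the ResNet50/101 style (the class name Bottleneck and the channel choices are illustrative only, not code used later): it squeezes the channels with a 1x1 conv, applies a 3x3 conv, then expands back with another 1x1 conv.
import torch.nn as nn

class Bottleneck(nn.Module):
    # Illustrative sketch of a bottleneck identity block (ResNet50/101 style)
    def __init__(self, ni, nh):  # ni: in/out channels, nh: squeezed channels (typically ni // 4)
        super().__init__()
        self.conv1 = nn.Conv2d(ni, nh, kernel_size=1, bias=False)  # squeeze
        self.bn1 = nn.BatchNorm2d(nh)
        self.conv2 = nn.Conv2d(nh, nh, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(nh)
        self.conv3 = nn.Conv2d(nh, ni, kernel_size=1, bias=False)  # expand back
        self.bn3 = nn.BatchNorm2d(ni)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        return self.relu(x + identity)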
We also have identity blocks and skip connection blocks. The difference is that ResIdentity blocks contain two convolutions, while ResSkip blocks contain three: two in the main path plus a 1x1 convolution in the skip path.
import torch
import torch.nn as nn

class ResIdentity(nn.Module):
    # Identity block: the skip connection passes the input through unchanged
    def __init__(self, ni):  # ni: number of input (and output) channels
        super().__init__()
        self.conv1 = nn.Conv2d(ni, ni, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(ni)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(ni, ni, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ni)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = x + identity  # add the unchanged input back in
        x = self.relu(x)
        return x
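A quick shape check (the sizes here are just an example): an identity block changes neither the channels nor the spatial size.
blk = ResIdentity(64)
print(blk(torch.rand(2, 64, 8, 8)).shape)  # torch.Size([2, 64, 8, 8]) -- unchanged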
class ResSkip(nn.Module):
    # Downsampling block: the skip path needs its own 1x1 conv to match the new channels and stride
    def __init__(self, ni, no):  # ni: input channels, no: output channels
        super().__init__()
        self.conv1 = nn.Conv2d(ni, no, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(no)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(no, no, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(no)
        self.skip_conv1 = nn.Conv2d(ni, no, kernel_size=1, stride=2, padding=0, bias=False)
        self.skip_bn1 = nn.BatchNorm2d(no)

    def forward(self, x):
        skip = self.skip_conv1(x)  # project the input to the new shape
        skip = self.skip_bn1(skip)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = x + skip
        x = self.relu(x)
        return x
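The same kind of check for the skip block: it halves the spatial resolution (stride 2) and changes the channel count, which is exactly why the skip path needs its own convolution.
blk = ResSkip(64, 128)
print(blk(torch.rand(2, 64, 8, 8)).shape)  # torch.Size([2, 128, 4, 4]) -- downsampled and widened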
What we call basic blocks come in two variants. The first is made of N identity blocks:
class NIdentityBlocks(nn.Module):
    def __init__(self, ni, repeat=2):
        super().__init__()
        self.block = nn.Sequential()
        for i in range(repeat):
            self.block.add_module(f"ResIdentity{i}", ResIdentity(ni))

    def forward(self, x):
        x = self.block(x)
        return x
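Stacking identity blocks naturally keeps the shape fixed:
stage = NIdentityBlocks(64, repeat=2)
print(stage(torch.rand(2, 64, 8, 8)).shape)  # torch.Size([2, 64, 8, 8])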
The second is made of one skip block followed by N-1 identity blocks:
class SkipAndNIdentityBlocks(nn.Module):
    def __init__(self, ni, no, repeat=2):
        super().__init__()
        self.block = nn.Sequential()
        self.block.add_module("ResSkip", ResSkip(ni, no))
        for i in range(repeat - 1):
            self.block.add_module(f"ResIdentity{i}", ResIdentity(no))

    def forward(self, x):
        x = self.block(x)
        return x
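Such a stage downsamples once at the entrance and then keeps the shape fixed:
stage = SkipAndNIdentityBlocks(64, 128, repeat=2)
print(stage(torch.rand(2, 64, 8, 8)).shape)  # torch.Size([2, 128, 4, 4])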
A ResNet also has a head and a tail, which roughly correspond to an encoder and a decoder: the head encodes the raw image into feature maps, and the tail decodes the final feature map into class scores.
class ResNetHead(nn.Module):
    # Stem: one 7x7 conv followed by a max pool
    def __init__(self, ni, no):
        super().__init__()
        self.conv = nn.Conv2d(ni, no, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn = nn.BatchNorm2d(no)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.maxpool(x)
        return x
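The head shrinks the input by a factor of 4 overall: stride 2 from the 7x7 conv and stride 2 from the max pool. For a standard 224x224 image:
head = ResNetHead(3, 64)
print(head(torch.rand(2, 3, 224, 224)).shape)  # torch.Size([2, 64, 56, 56])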
class ResNetTail(nn.Module):
    # Classifier: global average pool, flatten, then a linear layer
    def __init__(self, ni, no):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.lin = nn.Linear(ni, no)

    def forward(self, x):
        x = self.avg(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, ni)
        x = self.lin(x)
        return x
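Thanks to the adaptive average pool, the tail accepts any spatial size and always returns one vector of scores per image:
tail = ResNetTail(512, 1000)
print(tail(torch.rand(2, 512, 7, 7)).shape)  # torch.Size([2, 1000])
print(tail(torch.rand(2, 512, 1, 1)).shape)  # torch.Size([2, 1000]) -- spatial size doesn't matter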
Lastly, the full ResNet is a composition of these parts. The few things we define are the number of input channels (usually 3) and the number of outputs (usually 1000 classes). 64 is the initial number of planes (channels), and we end up with 512 channels. All ResNet architectures have these l0, l1, l2, l3 layers, where l1, l2, and l3 each double the number of channels.
class ResNet(nn.Module):
    def __init__(self, ni, no, repeat):
        super().__init__()
        self.head = ResNetHead(ni, 64)
        self.l0 = NIdentityBlocks(64, repeat[0])
        self.l1 = SkipAndNIdentityBlocks(64, 128, repeat[1])
        self.l2 = SkipAndNIdentityBlocks(128, 256, repeat[2])
        self.l3 = SkipAndNIdentityBlocks(256, 512, repeat[3])
        self.tail = ResNetTail(512, no)

    def forward(self, x):
        x = self.head(x)
        x = self.l0(x)
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.tail(x)
        return x
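The same composition also gives deeper basic-block ResNets just by changing the repeat counts; [3, 4, 6, 3] is the standard ResNet34 configuration:
my_resnet34 = ResNet(3, 1000, [3, 4, 6, 3])
print(sum(p.numel() for p in my_resnet34.parameters()))  # should match torchvision's resnet34 (21797672)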
The Check
i = torch.rand(85, 3, 32, 32)  # a batch of 85 RGB images of size 32x32
my_resnet18 = ResNet(3, 1000, [2, 2, 2, 2])
o = my_resnet18(i)
# print(o.size())
# print(my_resnet18)
nparams = sum(p.numel() for p in my_resnet18.parameters())
print(nparams)  # 11689512
You will find my_resnet18 has 11689512 parameters. This matches the resnet18 that ships with PyTorch (torchvision):
import torchvision.models as models

resnet18 = models.resnet18(weights=None)  # architecture only, no pretrained weights
nparams = sum(p.numel() for p in resnet18.parameters())
print(nparams)  # 11689512
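As a stricter sanity check (a quick sketch, reusing the two models above), the individual parameter tensors match in size too, not just the total:
mine = sorted(p.numel() for p in my_resnet18.parameters())
theirs = sorted(p.numel() for p in resnet18.parameters())
print(mine == theirs)  # True -- same multiset of parameter tensor sizes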
Initialization tip
We haven't initialized the conv and batch-norm layers in a custom way, but to do that I would use a loop like this:
import math

model = ResNet(3, 10, [2, 2, 2, 2]).to('cuda')  # assumes a CUDA GPU is available
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        torch.nn.init.kaiming_uniform_(m.weight, a=math.sqrt(3))
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.constant_(m.weight, 0.975)
        torch.nn.init.constant_(m.bias, 0.125)
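To confirm the loop actually ran, you can inspect one of the batch-norm layers (values shown up to float precision):
bn = next(m for m in model.modules() if isinstance(m, nn.BatchNorm2d))
print(bn.weight[0].item(), bn.bias[0].item())  # ~0.975 ~0.125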