Implementing a ResNet-34 CNN Using PyTorch

A while ago, I wrote an article, Implementing ResNet CNN, that explained ResNet Convolutional Neural Networks (CNNs) in detail and included a TensorFlow implementation. In this article, we take a closer look at ResNet34, a specific variant of the ResNet architecture, implement it in PyTorch, and train it on the CIFAR-10 dataset.
ResNet34 Architecture

The ResNet34 class constructs a complete ResNet-34 network consisting of 34 layers. Here’s a breakdown of its structure:
1. Stem (Initial Layers):
Conv2d: Converts 3 input channels to 64 filters with a 7×7 kernel and a stride of 2, which downsamples the image by a factor of 2.
BatchNorm2d + ReLU: Applies batch normalization followed by the ReLU activation function.
MaxPool2d: Further downsamples the image with a stride of 2.
2. Residual Blocks (Core):
The network comprises ResidualUnits grouped into four sections:
Stage 1: 3 units with 64 filters (stride = 1, maintaining spatial dimensions).
Stage 2: 4 units with 128 filters (the first unit has a stride of 2 for downsampling, while the remaining units have a stride of 1).
Stage 3: 6 units with 256 filters (the first unit has a stride of 2, while the rest have a stride of 1).
Stage 4: 3 units with 512 filters (the first unit has a stride of 2, while the rest have a stride of 1).
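In total, there are 3 + 4 + 6 + 3 = 16 residual blocks with 2 convolutional layers each, which gives 32 convolutional layers. Adding the stem convolution and the final fully connected layer gives 34 weighted layers overall (the 1×1 convolutions on the skip connections are not counted).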
Stride Logic:
A stride of 2 is used when the number of filters changes, which reduces spatial resolution and increases the number of channels.
A stride of 1 is maintained when the number of filters remains the same.
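To see how the strides shrink the feature maps, the sizes can be traced with the standard convolution output-size formula. This is a small sketch in plain Python; the helper names conv_out and trace_sizes are made up for illustration and are not part of PyTorch.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a convolution or pooling layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_sizes(input_size):
    """Follow a square input through the stem and the first unit of each stage."""
    sizes = {}
    size = conv_out(input_size, kernel=7, stride=2, padding=3)  # stem Conv2d
    sizes["stem conv"] = size
    size = conv_out(size, kernel=3, stride=2, padding=1)        # MaxPool2d
    sizes["max pool"] = size
    # Only the first unit of a stage can change the size; later units use stride 1.
    for stage, stride in enumerate([1, 2, 2, 2], start=1):
        size = conv_out(size, kernel=3, stride=stride, padding=1)
        sizes[f"stage {stage}"] = size
    return sizes


imagenet_sizes = trace_sizes(224)  # typical ImageNet-style input
cifar_sizes = trace_sizes(32)      # CIFAR-10 images are 32×32
print(imagenet_sizes)
print(cifar_sizes)
```

For a 224×224 input the four stages work on 56, 28, 14 and 7 pixel feature maps. For a 32×32 CIFAR-10 image they work on 8, 4, 2 and 1 pixel feature maps, so the last stage has very little spatial information left.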
3. Classification Head:
AdaptiveAvgPool2d: Performs global average pooling, resulting in an output shape of (batch_size, 512, 1, 1).
Flatten: Converts the output to a shape of (batch_size, 512).
LazyLinear: Maps the flattened output from 512 to 10 classes.
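The head can be tried on its own with a random tensor standing in for the output of the last stage. This is only a sketch: the batch size of 4 and the 7×7 feature map are assumptions chosen for the demo.

```python
import torch
import torch.nn as nn

# Same three layers as the classification head described above.
head = nn.Sequential(
    nn.AdaptiveAvgPool2d(output_size=1),
    nn.Flatten(),
    nn.LazyLinear(10),
)

features = torch.randn(4, 512, 7, 7)  # stand-in for the last stage's output

pooled = head[0](features)   # global average pooling
flat = head[1](pooled)       # drop the 1×1 spatial dimensions
logits = head(features)      # first call lets LazyLinear infer in_features

print(pooled.shape, flat.shape, logits.shape)
print(head[2].in_features)
```

LazyLinear does not need the input size up front. It reads it from the first batch it sees, which is why the head works without hard-coding 512.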
Key Design Points:
Progressively reduces spatial dimensions while increasing channels (56 → 28 → 14 → 7 for a 224×224 input; 8 → 4 → 2 → 1 for 32×32 CIFAR-10 images)
Each stage transition uses stride=2 to halve dimensions
Skip connections allow gradients to flow through all 34 layers
Total parameters: ~21.3 million (with the 10-class output layer)
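The parameter count can be checked by hand from the layer list above. The sketch below counts only learnable weights (convolutions without bias, the batch norm scale and shift, and the final linear layer), and it assumes a 1×1 projection on the skip connection wherever the stride is 2. The function names are made up for this example.

```python
def conv_params(c_in, c_out, kernel):
    # Convolution without bias: one kernel×kernel filter per (input, output) channel pair.
    return c_in * c_out * kernel * kernel


def bn_params(channels):
    # BatchNorm2d learns one scale and one shift per channel.
    return 2 * channels


def residual_unit_params(c_in, c_out, stride):
    total = conv_params(c_in, c_out, 3) + bn_params(c_out)
    total += conv_params(c_out, c_out, 3) + bn_params(c_out)
    if stride > 1:
        # 1×1 projection + batch norm on the skip connection
        total += conv_params(c_in, c_out, 1) + bn_params(c_out)
    return total


def resnet34_params(num_classes=10):
    total = conv_params(3, 64, 7) + bn_params(64)  # stem
    prev = 64
    for filters in [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3:
        stride = 1 if filters == prev else 2
        total += residual_unit_params(prev, filters, stride)
        prev = filters
    total += 512 * num_classes + num_classes  # linear layer weights + bias
    return total


total_params = resnet34_params(num_classes=10)
print(f"{total_params:,}")
```

This gives 21,289,802 parameters for 10 classes, so roughly 21.3 million. Almost all of them sit in the 3×3 convolutions of the last two stages.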
Let's implement this in PyTorch.
Import the packages
import numpy as np
import torch
from sklearn.datasets import load_sample_images
import matplotlib.pyplot as plt
import torchvision
import torch.nn as nn
import torchvision.transforms.v2 as T
from functools import partial
import torchmetrics
import torch.nn.functional as F
ResidualUnit

class ResidualUnit(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # 3×3 convolution defaults; bias is omitted because BatchNorm follows
        DefaultConv2d = partial(
            nn.Conv2d, kernel_size=3, stride=1, padding=1, bias=False)
        self.main_layers = nn.Sequential(
            DefaultConv2d(in_channels, out_channels, stride=stride),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            DefaultConv2d(out_channels, out_channels),
            nn.BatchNorm2d(out_channels),
        )
        if stride > 1:
            # 1×1 projection so the skip path matches the main path's shape
            self.skip_connection = nn.Sequential(
                DefaultConv2d(in_channels, out_channels, kernel_size=1,
                              stride=stride, padding=0),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.skip_connection = nn.Identity()

    def forward(self, inputs):
        return F.relu(self.main_layers(inputs) + self.skip_connection(inputs))
The ResidualUnit class implements a residual block, which is the core building block of ResNet (Residual Networks). Here's the breakdown:
Key Components
1. Main Path
Two convolutional blocks in sequence:
- Conv2d → BatchNorm2d → ReLU → Conv2d → BatchNorm2d
The first Conv2d uses the stride parameter (for downsampling if needed)
The second Conv2d always uses stride=1
2. Skip Connection
If stride > 1: Creates a 1×1 convolution with the specified stride + batch norm (adjusts dimensions and spatial resolution)
If stride = 1: Uses nn.Identity() (passes input unchanged)
This ensures the skip connection has the same dimensions as the main path output
3. Forward Pass
Adds the outputs of main_layers and skip_connection
Applies ReLU activation to the sum
The key innovation is adding the skip connection's output to the main path's output. This allows:
Gradients to bypass layers during backpropagation (easier training)
The network to learn residual mappings (differences) rather than full transformations
Training of very deep networks without degradation
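The first point can be seen in a toy example. In this sketch a small linear layer with zeroed weights stands in for a main path that passes no signal at all; main_path is a made-up stand-in, not part of the ResidualUnit class.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical main path whose weights are all zero, so it blocks every gradient.
main_path = nn.Linear(4, 4, bias=False)
nn.init.zeros_(main_path.weight)

x = torch.randn(2, 4, requires_grad=True)

# Without a skip connection, no gradient reaches the input.
main_path(x).sum().backward()
grad_plain = x.grad.clone()

# With a skip connection, the identity path still carries the gradient.
x.grad.zero_()
(main_path(x) + x).sum().backward()
grad_residual = x.grad.clone()

print(grad_plain)
print(grad_residual)
```

The plain version gives an all-zero gradient, while the residual version gives a gradient of 1 for every input element, because the derivative of F(x) + x with respect to x always includes the identity term.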
ResNet34
The ResNet34 class builds the complete ResNet34 architecture using the ResidualUnit class. The architecture diagram for ResNet34 is shown earlier.
class ResNet34(nn.Module):
    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2,
                      padding=3, bias=False),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        prev_filters = 64
        for filters in [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3:
            stride = 1 if filters == prev_filters else 2
            layers.append(ResidualUnit(prev_filters, filters, stride=stride))
            prev_filters = filters
        layers += [
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Flatten(),
            nn.LazyLinear(10),
        ]
        self.resnet = nn.Sequential(*layers)

    def forward(self, inputs):
        return self.resnet(inputs)
Loading the CIFAR-10 dataset
# Load CIFAR-10 Dataset
transform = T.Compose([
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True),
    # These are the ImageNet channel statistics
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Load full training set (download=False assumes the data is already in ./datasets)
train_valid_dataset = torchvision.datasets.CIFAR10(root='./datasets', train=True,
                                                   download=False, transform=transform)
# Load test set for final evaluation
test_dataset = torchvision.datasets.CIFAR10(root='./datasets', train=False,
                                            download=False, transform=transform)
# Split the 50,000 training images into training and validation sets
torch.manual_seed(42)
train_dataset, valid_dataset = torch.utils.data.random_split(
    train_valid_dataset, [45_000, 5_000]
)
# Create Data Loaders
batch_size = 128
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset, batch_size=batch_size, shuffle=False
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False
)
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(valid_dataset)}")
print(f"Testing samples: {len(test_dataset)}")
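The code below sets up the training environment for the ResNet34 model:
Selects the device (Apple's MPS GPU backend if available, otherwise CPU)
Creates a new ResNet34 instance and moves it to the selected device
Defines the loss function for multi-class classification
Creates an Adam optimizer with a learning rate of 0.001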
# Setup for Training
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
# Initialize model
model = ResNet34().to(device)
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
The train_epoch function trains the model for one complete pass through the training dataset and returns the average loss and accuracy for the entire epoch.
Step-by-Step Breakdown
1. Set Model to Training Mode
2. Initialize Tracking Variables
3. Loop Through Each Batch
4. Move the Batch to the Device
5. Forward Pass
6. Backward Pass (Clear Old Gradients, Then Compute New Ones)
7. Optimizer Step
Updates all model weights using computed gradients
Moves weights in direction that reduces loss
8. Track Metrics
# Training function
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    avg_loss = total_loss / len(train_loader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy
Train The Model
# Train for 10 epochs
num_epochs = 10
train_losses = []
train_accs = []
valid_losses = []
valid_accs = []
print("Starting training for 10 epochs...")
print("=" * 80)
for epoch in range(num_epochs):
    # Train for one epoch
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    # Validate for one epoch
    valid_loss, valid_acc = test_epoch(model, valid_loader, criterion, device)
    valid_losses.append(valid_loss)
    valid_accs.append(valid_acc)
    # Print results for each epoch
    print(f"Epoch [{epoch+1:2d}/{num_epochs}] | "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:6.2f}% | "
          f"Valid Loss: {valid_loss:.4f} | Valid Acc: {valid_acc:6.2f}%")
print("=" * 80)
print("Training completed!")
# Save the trained model before reporting that it was saved
torch.save(model.state_dict(), './my_resnet34_checkpoint.pt')
print("Model saved to './my_resnet34_checkpoint.pt'")
Starting training for 10 epochs...
================================================================================
Epoch [ 1/10] | Train Loss: 0.1853 | Train Acc: 93.58% | Valid Loss: 0.6287 | Valid Acc: 82.22%
Epoch [ 2/10] | Train Loss: 0.1498 | Train Acc: 94.80% | Valid Loss: 0.7269 | Valid Acc: 80.48%
Epoch [ 3/10] | Train Loss: 0.1282 | Train Acc: 95.52% | Valid Loss: 0.7559 | Valid Acc: 80.24%
Epoch [ 4/10] | Train Loss: 0.0961 | Train Acc: 96.71% | Valid Loss: 0.8131 | Valid Acc: 80.16%
Epoch [ 5/10] | Train Loss: 0.0948 | Train Acc: 96.57% | Valid Loss: 0.8196 | Valid Acc: 80.94%
Epoch [ 6/10] | Train Loss: 0.0853 | Train Acc: 97.06% | Valid Loss: 0.8924 | Valid Acc: 79.26%
Epoch [ 7/10] | Train Loss: 0.0755 | Train Acc: 97.44% | Valid Loss: 0.8582 | Valid Acc: 80.14%
Epoch [ 8/10] | Train Loss: 0.0661 | Train Acc: 97.79% | Valid Loss: 0.9182 | Valid Acc: 80.18%
Epoch [ 9/10] | Train Loss: 0.0653 | Train Acc: 97.71% | Valid Loss: 0.9218 | Valid Acc: 80.42%
Epoch [10/10] | Train Loss: 0.0518 | Train Acc: 98.22% | Valid Loss: 0.9642 | Valid Acc: 79.80%
================================================================================
Training completed!
Model saved to './my_resnet34_checkpoint.pt'
Chart of losses and accuracy for training and validation data

======================================================================
TRAINING SUMMARY
======================================================================
Total Epochs Trained: 10
Final Metrics:
Train Loss: 0.0518
Train Accuracy: 98.22%
Valid Loss: 0.9642
Valid Accuracy: 79.80%
Best Validation Metrics:
Best Valid Accuracy: 82.22% (Epoch 1)
Best Valid Loss: 0.6287 (Epoch 1)
======================================================================
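The best-epoch figures in the summary can be derived from the per-epoch lists collected during training. The sketch below uses the values printed in the log above, copied by hand. The variable names best_acc_epoch, best_loss_epoch and final_gap are made up for this example.

```python
# Per-epoch values copied from the training log above.
train_accs = [93.58, 94.80, 95.52, 96.71, 96.57, 97.06, 97.44, 97.79, 97.71, 98.22]
valid_accs = [82.22, 80.48, 80.24, 80.16, 80.94, 79.26, 80.14, 80.18, 80.42, 79.80]
valid_losses = [0.6287, 0.7269, 0.7559, 0.8131, 0.8196,
                0.8924, 0.8582, 0.9182, 0.9218, 0.9642]

# Epochs are reported 1-based, so add 1 to the list index.
best_acc_epoch = max(range(len(valid_accs)), key=valid_accs.__getitem__) + 1
best_loss_epoch = min(range(len(valid_losses)), key=valid_losses.__getitem__) + 1

# Gap between training and validation accuracy after the last epoch.
final_gap = train_accs[-1] - valid_accs[-1]

print(f"Best valid accuracy: {max(valid_accs):.2f}% (Epoch {best_acc_epoch})")
print(f"Best valid loss: {min(valid_losses):.4f} (Epoch {best_loss_epoch})")
print(f"Final train/valid accuracy gap: {final_gap:.2f} points")
```

Training accuracy keeps rising while validation loss grows from 0.6287 to 0.9642, and the final gap is about 18 points. This pattern suggests the model is overfitting during these 10 epochs, so it may be worth saving the checkpoint from the epoch with the best validation loss instead of the last one.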
Evaluating on Test Data
# Evaluate on Test Data
print("Evaluating model on test data...")
test_loss, test_acc = test_epoch(model, test_loader, criterion, device)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2f}%")
Evaluating model on test data...
Test Loss: 1.2146
Test Accuracy: 76.64%



