Implementing a ResNet-34 CNN Using PyTorch


A while ago, I wrote an article, Implementing ResNet CNN, that explained ResNet Convolutional Neural Networks (CNNs) in detail along with a TensorFlow implementation. In this article, we take a closer look at ResNet-34, a specific variant of the ResNet architecture, and implement it using PyTorch. This lets us explore PyTorch's features while leveraging the capabilities of ResNet-34 for deep learning tasks.

ResNet34 Architecture

The ResNet34 class constructs a complete ResNet-34 network consisting of 34 layers. Here’s a breakdown of its structure:

1. Stem (Initial Layers):

  • Conv2d: Maps the 3 input channels to 64 output channels with a 7×7 kernel and a stride of 2, which downsamples the image by a factor of 2.

  • BatchNorm2d + ReLU: Applies batch normalization followed by the ReLU activation function.

  • MaxPool2d: Further downsamples the image with a stride of 2.
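
To see the stem's downsampling in action, here is a quick standalone shape check (a 224×224 input is assumed):

```python
import torch
import torch.nn as nn

# Stem: 7x7 stride-2 conv, batch norm, ReLU, then stride-2 max pooling
stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)

x = torch.randn(1, 3, 224, 224)  # dummy RGB image
print(stem(x).shape)             # torch.Size([1, 64, 56, 56])
```

Two stride-2 operations reduce 224 to 112 and then to 56, which is where Stage 1 of the residual blocks takes over.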

2. Residual Blocks (Core):

The network comprises ResidualUnits grouped into four sections:

  • Stage 1: 3 units with 64 filters (stride = 1, maintaining spatial dimensions).

  • Stage 2: 4 units with 128 filters (the first unit has a stride of 2 for downsampling, while the remaining units have a stride of 1).

  • Stage 3: 6 units with 256 filters (the first unit has a stride of 2, while the rest have a stride of 1).

  • Stage 4: 3 units with 512 filters (the first unit has a stride of 2, while the rest have a stride of 1).

In total, there are 3 + 4 + 6 + 3 = 16 residual blocks, giving 32 convolutional layers. Together with the initial 7×7 convolution and the final fully connected layer, that makes 34 layers overall.

Stride Logic:

  • A stride of 2 is used when the number of filters changes, which reduces spatial resolution and increases the number of channels.

  • A stride of 1 is maintained when the number of filters remains the same.
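
The stride rule above can be sketched as a small loop over the ResNet-34 filter plan:

```python
# Stride is 2 exactly when the filter count changes, i.e. at each stage boundary
prev_filters = 64
strides = []
for filters in [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3:
    strides.append(1 if filters == prev_filters else 2)
    prev_filters = filters

print(strides)
# [1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1]
```

The three 2s mark the transitions into Stages 2, 3, and 4, where spatial resolution halves and the channel count doubles.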

3. Classification Head:

  • AdaptiveAvgPool2d: Performs global average pooling, resulting in an output shape of (batch_size, 512, 1, 1).

  • Flatten: Converts the output to a shape of (batch_size, 512).

  • LazyLinear: Maps the flattened output from 512 to 10 classes.

Key Design Points:

  • Progressively reduces spatial dimensions (56 → 28 → 14 → 7 for a 224×224 input) while increasing channels

  • Each stage transition uses stride=2 to halve dimensions

  • Skip connections allow gradients to flow through all 34 layers

  • Total parameters: ~21.3 million (with the 10-class head used here)

Let's implement this in PyTorch.

Import the packages

import numpy as np
import torch
from sklearn.datasets import load_sample_images
import matplotlib.pyplot as plt
import torchvision
import torch.nn as nn
import torchvision.transforms.v2 as T
from functools import partial
import torchmetrics
import torch.nn.functional as F

ResidualUnit

class ResidualUnit(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        DefaultConv2d = partial(
            nn.Conv2d, kernel_size=3, stride=1, padding=1, bias=False)
        self.main_layers = nn.Sequential(
            DefaultConv2d(in_channels, out_channels, stride=stride),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            DefaultConv2d(out_channels, out_channels),
            nn.BatchNorm2d(out_channels),
        )
        if stride > 1:
            self.skip_connection = nn.Sequential(
                DefaultConv2d(in_channels, out_channels, kernel_size=1,
                              stride=stride, padding=0),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.skip_connection = nn.Identity()

    def forward(self, inputs):
        return F.relu(self.main_layers(inputs) + self.skip_connection(inputs))

The ResidualUnit class implements a residual block, the core building block of ResNet (Residual Networks). Here's the breakdown:

Key Components

1. Main Path

  • Two convolutional blocks in sequence:

    • Conv2d → BatchNorm2d → ReLU → Conv2d → BatchNorm2d
  • The first Conv2d uses the stride parameter (for downsampling if needed)

  • The second Conv2d always uses stride=1

2. Skip Connection

  • If stride > 1: Creates a 1×1 convolution with the specified stride + batch norm (adjusts dimensions and spatial resolution)

  • If stride = 1: Uses nn.Identity() (passes input unchanged)

  • This ensures the skip connection has the same dimensions as the main path output

3. Forward Pass

  • Adds the output of main_layers and skip_connection

  • Applies ReLU activation to the sum

The key innovation is the addition of the skip connection to the main path. This allows:

  • Gradients to bypass layers during backpropagation (easier training)

  • The network to learn residual mappings (differences) rather than full transformations

  • Training of very deep networks without degradation
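
The 1×1 projection shortcut can be seen in isolation with a quick shape check (a standalone sketch; the 64→128-channel, stride-2 case is assumed):

```python
import torch
import torch.nn as nn

# 1x1 projection shortcut, as used when stride > 1: it matches both the
# channel count and the halved spatial resolution of the main path
projection = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, padding=0, bias=False),
    nn.BatchNorm2d(128),
)

x = torch.randn(2, 64, 56, 56)
print(projection(x).shape)  # torch.Size([2, 128, 28, 28])
```

Because both paths now produce (2, 128, 28, 28) tensors, the element-wise addition in the forward pass is well defined.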

ResNet34

The ResNet34 class builds the complete ResNet-34 architecture from the ResidualUnit class, following the structure outlined earlier.

class ResNet34(nn.Module):
    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2,
                      padding=3, bias=False),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        prev_filters = 64
        for filters in [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3:
            stride = 1 if filters == prev_filters else 2
            layers.append(ResidualUnit(prev_filters, filters, stride=stride))
            prev_filters = filters
        layers += [
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Flatten(),
            nn.LazyLinear(10),
        ]
        self.resnet = nn.Sequential(*layers)

    def forward(self, inputs):
        return self.resnet(inputs)

Loading the CIFAR-10 dataset

# Load CIFAR-10 Dataset
transform = T.Compose([
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load full training set
train_valid_dataset = torchvision.datasets.CIFAR10(root='./datasets', train=True,
                                                   download=True, transform=transform)
# Load test set for final evaluation
test_dataset = torchvision.datasets.CIFAR10(root='./datasets', train=False,
                                            download=True, transform=transform)

torch.manual_seed(42)
train_dataset, valid_dataset = torch.utils.data.random_split(
    train_valid_dataset, [45_000, 5_000]
)

# Create Data Loaders
batch_size = 128
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset, batch_size=batch_size, shuffle=False
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(valid_dataset)}")
print(f"Testing  samples: {len(test_dataset)}")
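
The seeded random_split above is reproducible; here is a toy sketch of the same pattern (dummy tensors assumed in place of CIFAR-10):

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Same pattern as the 45,000/5,000 CIFAR-10 split above, on toy data
data = TensorDataset(torch.arange(100).float().unsqueeze(1),
                     torch.zeros(100, dtype=torch.long))

torch.manual_seed(42)
train_a, _ = random_split(data, [90, 10])
torch.manual_seed(42)
train_b, _ = random_split(data, [90, 10])

# Re-seeding yields the identical partition
print(train_a.indices == train_b.indices)  # True
```

Seeding before the split matters: without it, each run would shuffle the training/validation boundary differently.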

The code below sets up the training environment for the ResNet34 model:

  • Creates a new ResNet34 instance

  • Moves the model to the selected device (GPU or CPU)

  • Defines the loss function for multi-class classification

# Setup for Training
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device}")

# Initialize model
model = ResNet34().to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

The train_epoch function trains the model for one complete pass through the training dataset and returns the average loss and accuracy for the epoch.

Step-by-Step Breakdown

1. Set Model to Training Mode

2. Initialize Tracking Variables

3. Loop Through Each Batch

4. Do the Forward Pass

5. Do the Backward Pass (Compute Gradients)

  • Updates all model weights using the computed gradients

  • Moves weights in the direction that reduces loss

6. Track Metrics

# Training function
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    avg_loss = total_loss / len(train_loader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy

Train The Model

# Train for 10 epochs
num_epochs = 10
train_losses = []
train_accs = []
valid_losses = []
valid_accs = []

print("Starting training for 10 epochs...")
print("=" * 80)

for epoch in range(num_epochs):
    # Train for one epoch
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # Validate for one epoch
    valid_loss, valid_acc = test_epoch(model, valid_loader, criterion, device)
    valid_losses.append(valid_loss)
    valid_accs.append(valid_acc)

    # Print results for each epoch
    print(f"Epoch [{epoch+1:2d}/{num_epochs}] | "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:6.2f}% | "
          f"Valid Loss: {valid_loss:.4f} | Valid Acc: {valid_acc:6.2f}%")

print("=" * 80) 
print("Training completed!")

# Save the trained model
torch.save(model.state_dict(), './my_resnet34_checkpoint.pt')
print("Model saved to './my_resnet34_checkpoint.pt'")
Starting training for 10 epochs...
================================================================================
Epoch [ 1/10] | Train Loss: 0.1853 | Train Acc:  93.58% | Valid Loss: 0.6287 | Valid Acc:  82.22%
Epoch [ 2/10] | Train Loss: 0.1498 | Train Acc:  94.80% | Valid Loss: 0.7269 | Valid Acc:  80.48%
Epoch [ 3/10] | Train Loss: 0.1282 | Train Acc:  95.52% | Valid Loss: 0.7559 | Valid Acc:  80.24%
Epoch [ 4/10] | Train Loss: 0.0961 | Train Acc:  96.71% | Valid Loss: 0.8131 | Valid Acc:  80.16%
Epoch [ 5/10] | Train Loss: 0.0948 | Train Acc:  96.57% | Valid Loss: 0.8196 | Valid Acc:  80.94%
Epoch [ 6/10] | Train Loss: 0.0853 | Train Acc:  97.06% | Valid Loss: 0.8924 | Valid Acc:  79.26%
Epoch [ 7/10] | Train Loss: 0.0755 | Train Acc:  97.44% | Valid Loss: 0.8582 | Valid Acc:  80.14%
Epoch [ 8/10] | Train Loss: 0.0661 | Train Acc:  97.79% | Valid Loss: 0.9182 | Valid Acc:  80.18%
Epoch [ 9/10] | Train Loss: 0.0653 | Train Acc:  97.71% | Valid Loss: 0.9218 | Valid Acc:  80.42%
Epoch [10/10] | Train Loss: 0.0518 | Train Acc:  98.22% | Valid Loss: 0.9642 | Valid Acc:  79.80%
================================================================================
Training completed!
Model saved to './my_resnet34_checkpoint.pt'

Chart of losses and accuracy for training and validation data

======================================================================
TRAINING SUMMARY
======================================================================

Total Epochs Trained: 10

Final Metrics:
Train Loss: 0.0518 | Train Accuracy: 98.22%
Valid Loss: 0.9642 | Valid Accuracy: 79.80%

Best Validation Metrics:
Best Valid Accuracy: 82.22% (Epoch 1)
Best Valid Loss: 0.6287 (Epoch 1)
======================================================================

Evaluating on Test Data

# Evaluate on Test Data
print("Evaluating model on test data...")
test_loss, test_acc = test_epoch(model, test_loader, criterion, device)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2f}%")
Evaluating model on test data...
Test Loss: 1.2146
Test Accuracy: 76.64%
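
To use the saved checkpoint later, the state_dict can be reloaded into a fresh model instance. A minimal round-trip sketch of the pattern (a toy module stands in for ResNet34 here):

```python
import torch
import torch.nn as nn

# Save/load round-trip on a toy module; the same pattern applies to
# the ResNet34 checkpoint saved above
net = nn.Linear(8, 4)
torch.save(net.state_dict(), './demo_checkpoint.pt')

restored = nn.Linear(8, 4)
restored.load_state_dict(torch.load('./demo_checkpoint.pt'))
restored.eval()  # switch to inference mode before predicting

x = torch.randn(2, 8)
print(torch.equal(net(x), restored(x)))  # True: identical weights, identical outputs
```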