Lecture 7: Convolutional Neural Networks#

Handling image data

Joaquin Vanschoren, Eindhoven University of Technology

Overview#

  • Image convolution

  • Convolutional neural networks

  • Data augmentation

  • Model interpretation

  • Using pre-trained networks (transfer learning)

# Auto-setup when running on Google Colab
import os
import tensorflow as tf
if 'google.colab' in str(get_ipython()) and not os.path.exists('/content/master'):
    !git clone -q https://github.com/ML-course/master.git /content/master
    !pip --quiet install -r /content/master/requirements_colab.txt
    %cd master/notebooks

# Global imports and settings
%matplotlib inline
from preamble import *
interactive = True # Set to True for interactive plots 
if interactive:
    fig_scale = 0.5
    plt.rcParams.update(print_config)
else: # For printing
    fig_scale = 0.4
    plt.rcParams.update(print_config)
%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}
import pickle 
data_dir = '../data/cats-vs-dogs_small'
model_dir = '../data/models'
if not os.path.exists(data_dir):
    os.makedirs(data_dir)
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
    
with open("../data/histories.pkl", "rb") as f:
    histories = pickle.load(f)

Convolutions#

  • Operation that transforms an image by sliding a smaller image (called a filter or kernel) over it and multiplying pixel values

    • Slide an \(n \times n\) filter over \(n \times n\) patches of the original image

    • Every pixel is replaced by the sum of the element-wise products of the kernel and the image patch around that pixel

# kernel and image_patch are n x n matrices
pixel_out = np.sum(kernel * image_patch)
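As a minimal sketch of my own (assuming a grayscale NumPy image), the full 'valid' convolution just applies this sum at every position:

import numpy as np

def convolve2d(image, kernel):
    """Naive 'valid' convolution: no padding, stride 1."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            # Element-wise product of the kernel and the image patch, then sum
            out[i, j] = np.sum(kernel * image[i:i+kh, j:j+kw])
    return out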
from __future__ import print_function
import ipywidgets as widgets
from ipywidgets import interact, interact_manual, Dropdown
from skimage import color


# Visualize convolution. See https://tonysyu.github.io/
def iter_pixels(image):
    """ Yield pixel position (row, column) and pixel intensity. """
    height, width = image.shape[:2]
    for i in range(height):
        for j in range(width):
            yield (i, j), image[i, j]
            
# Visualize result
def imshow_pair(image_pair, titles=('', ''), figsize=(8, 4), **kwargs):
    fig, axes = plt.subplots(ncols=2, figsize=figsize)
    for ax, img, label in zip(axes.ravel(), image_pair, titles):
        ax.imshow(img, **kwargs)
        ax.set_title(label, fontdict={'fontsize':32*fig_scale})
        ax.set_xticks([])
        ax.set_yticks([])
        
# Visualize result
def imshow_triple(axes, image_pair, titles=('', '', ''), figsize=(8, 4), **kwargs):
    for ax, img, label in zip(axes, image_pair, titles):
        ax.imshow(img, **kwargs)
        ax.set_title(label, fontdict={'fontsize':10*fig_scale})
        ax.set_xticks([])
        ax.set_yticks([])
        
# Zero-padding
def padding_for_kernel(kernel):
    """ Return the amount of padding needed for each side of an image.

    For example, if the returned result is [1, 2], then this means an
    image should be padded with 1 extra row on top and bottom, and 2
    extra columns on the left and right.
    """
    # Slice to ignore RGB channels if they exist.
    kernel_shape = kernel.shape[:2]
    # We only handle kernels with odd dimensions so make sure that's true.
    # (The "center" pixel of an even number of pixels is arbitrary.)
    assert all((size % 2) == 1 for size in kernel_shape)
    return [(size - 1) // 2 for size in kernel_shape]
def add_padding(image, kernel):
    h_pad, w_pad = padding_for_kernel(kernel)
    return np.pad(image, ((h_pad, h_pad), (w_pad, w_pad)),
                  mode='constant', constant_values=0)
def remove_padding(image, kernel):
    inner_region = []  # A 2D slice for grabbing the inner image region
    for pad in padding_for_kernel(kernel):
        slice_i = np.s_[:] if pad == 0 else np.s_[pad: -pad]
        inner_region.append(slice_i)
    return image[tuple(inner_region)]  # Index with a tuple (list-based indexing was removed in NumPy 1.24)

# Slice windows
def window_slice(center, kernel):
    r, c = center
    r_pad, c_pad = padding_for_kernel(kernel)
    # Slicing is (inclusive, exclusive) so add 1 to the stop value
    return np.s_[r-r_pad:r+r_pad+1, c-c_pad:c+c_pad+1]
        

# Apply convolution kernel to image patch
def apply_kernel(center, kernel, original_image):
    image_patch = original_image[window_slice(center, kernel)]
    # An element-wise multiplication followed by the sum
    return np.sum(kernel * image_patch)

# Move kernel over the image
def iter_kernel_labels(image, kernel):
    original_image = image
    image = add_padding(original_image, kernel)
    i_pad, j_pad = padding_for_kernel(kernel)

    for (i, j), pixel in iter_pixels(original_image):
        # Shift the center of the kernel to ignore padded border.
        i += i_pad
        j += j_pad
        mask = np.zeros(image.shape, dtype=int)  # Background = 0
        mask[window_slice((i, j), kernel)] = kernel   # Overlay the kernel values
        #mask[i, j] = 2                           # Kernel-center = 2
        yield (i, j), mask

# Visualize kernel as it moves over the image
def visualize_kernel(kernel_labels, image):
    return kernel_labels + image #color.label2rgb(kernel_labels, image, bg_label=0)

def convolution_demo(image, kernels, **kwargs):
    # Dropdown for selecting kernels
    kernel_names = list(kernels.keys())
    kernel_selector = Dropdown(options=kernel_names, description='Kernel:')
    
    def update_convolution(kernel_name):
        kernel = kernels[kernel_name]  # Get the selected kernel
        gen_kernel_labels = iter_kernel_labels(image, kernel)
        
        image_cache = []
        image_padded = add_padding(image, kernel)
        
        def convolution_step(i_step=0):
            while i_step >= len(image_cache):
                filtered_prev = image_padded if i_step == 0 else image_cache[-1][1]
                filtered = filtered_prev.copy()
                
                center, kernel_labels = next(gen_kernel_labels)
                filtered[center] = apply_kernel(center, kernel, image_padded)
                kernel_overlay = visualize_kernel(kernel_labels, image_padded)
                
                image_cache.append((kernel_overlay, filtered))
                
            image_pair = [remove_padding(each, kernel) for each in image_cache[i_step]]
            imshow_pair(image_pair, **kwargs)
            plt.show()
        
        interact(convolution_step, i_step=(0, image.size - 1, 1))
    
    interact(update_convolution, kernel_name=kernel_selector);

# Full process
def convolution_full(ax, image, kernel, **kwargs):
    # Initialize generator since we're only ever going to iterate over
    # a pixel once. The cached result is used, if we step back.
    gen_kernel_labels = iter_kernel_labels(image, kernel)

    image_cache = []
    image_padded = add_padding(image, kernel)
    # Plot original image and kernel-overlay next to filtered image.

    for i_step in range(image.size):  # One step per pixel

        # For the first step (`i_step == 0`), the original image is the
        # filtered image; after that we look in the cache, which stores
        # (`kernel_overlay`, `filtered`).
        filtered_prev = image_padded if i_step == 0 else image_cache[-1][1]
        # We don't want to overwrite the previously filtered image:
        filtered = filtered_prev.copy()

        # Get the labels used to visualize the kernel
        center, kernel_labels = next(gen_kernel_labels)
        # Modify the pixel value at the kernel center
        filtered[center] = apply_kernel(center, kernel, image_padded)
        # Take the original image and overlay our kernel visualization
        kernel_overlay = visualize_kernel(kernel_labels, image_padded)
        # Save images for reuse.
        image_cache.append((kernel_overlay, filtered))

    # Remove padding we added to deal with boundary conditions
    # (Loop since each step has 2 images)
    image_triple = [remove_padding(each, kernel)
                  for each in image_cache[i_step]]
    image_triple.insert(1,kernel)
    imshow_triple(ax, image_triple, **kwargs)
  • Different kernels can detect different types of patterns in the image

horizontal_edge_kernel = np.array([[ 1,  2,  1],
                                   [ 0,  0,  0],
                                   [-1, -2, -1]])
diagonal_edge_kernel = np.array([[1, 0, 0],
                                 [0, 1, 0],
                                 [0, 0, 1]])
edge_detect_kernel = np.array([[-1, -1, -1],
                               [-1,  8, -1],
                               [-1, -1, -1]])
all_kernels = {"horizontal": horizontal_edge_kernel,
               "diagonal": diagonal_edge_kernel,
               "edge_detect":edge_detect_kernel}
mnist_data = oml.datasets.get_dataset(554) # Download MNIST data
# Get the predictors X and the labels y
X_mnist, y_mnist, c, a = mnist_data.get_data(dataset_format='array', target=mnist_data.default_target_attribute); 
image = X_mnist[1].reshape((28, 28))
image = (image - np.min(image))/np.ptp(image) # Normalize

titles = ('Image and kernel', 'Filtered image')
convolution_demo(image, all_kernels, vmin=-4, vmax=4, titles=titles, cmap='gray_r');
if not interactive:
    fig, axs = plt.subplots(3,  3, figsize=(5*fig_scale, 5*fig_scale))
    titles = ('Image and kernel', 'Hor. edge filter', 'Filtered image')
    convolution_full(axs[0,:], image, horizontal_edge_kernel, vmin=-4, vmax=4, titles=titles, cmap='gray_r')
    titles = ('Image and kernel', 'Edge detect filter', 'Filtered image')
    convolution_full(axs[1,:], image, edge_detect_kernel, vmin=-4, vmax=4, titles=titles, cmap='gray_r')
    titles = ('Image and kernel', 'Diag. edge filter', 'Filtered image')
    convolution_full(axs[2,:], image, diagonal_edge_kernel, vmin=-4, vmax=4, titles=titles, cmap='gray_r')
    plt.tight_layout()

Demonstration on Fashion-MNIST#

fmnist_data = oml.datasets.get_dataset(40996) # Download FMNIST data
# Get the predictors X and the labels y
X_fm, y_fm, _, _ = fmnist_data.get_data(dataset_format='array', target=fmnist_data.default_target_attribute)
fm_classes = {0:"T-shirt/top", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat", 5: "Sandal", 
              6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle boot"}
# Build a list of figures for plotting
def buildFigureList(fig, subfiglist, titles, length):
    for i in range(length):
        pixels = np.array(subfiglist[i], dtype='float').reshape((28, 28))
        ax = fig.add_subplot(1, length, i + 1)
        ax.imshow(pixels, cmap='gray_r')
        ax.set_title(fm_classes[titles[i]], fontsize=6)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)

subfiglist = []
titles=[]

for i in range(0,7):
    subfiglist.append(X_fm[i])
    titles.append(i)

buildFigureList(plt.figure(1),subfiglist, titles, 7)
plt.show()
Figure: the first seven Fashion-MNIST images with their class labels.

Demonstration

def normalize_image(X):
    image = X.reshape((28, 28))
    return (image - np.min(image))/np.ptp(image) # Normalize

image = normalize_image(X_fm[3])
demo2 = convolution_demo(image, all_kernels,
                 vmin=-4, vmax=4, cmap='gray_r');
if not interactive:
    fig, axs = plt.subplots(3, 3, figsize=(5*fig_scale, 5*fig_scale))
    titles = ('Image and kernel', 'Hor. edge filter', 'Filtered image')
    convolution_full(axs[0,:], image, horizontal_edge_kernel, vmin=-4, vmax=4, titles=titles, cmap='gray_r')
    titles = ('Image and kernel', 'Diag. edge filter', 'Filtered image')
    convolution_full(axs[1,:], image, diagonal_edge_kernel, vmin=-4, vmax=4, titles=titles, cmap='gray_r')
    titles = ('Image and kernel', 'Edge detect filter', 'Filtered image')
    convolution_full(axs[2,:], image, edge_detect_kernel, vmin=-4, vmax=4, titles=titles, cmap='gray_r')
    plt.tight_layout()

Image convolution in practice#

  • How do we know which filters are best for a given image?

  • Families of kernels (or filter banks ) can be run on every image

    • Gabor, Sobel, Haar Wavelets,…

  • Gabor filters: Wave patterns generated by changing:

    • Frequency: narrow or wide undulations

    • Theta: angle (direction) of the wave

    • Sigma: resolution (the size of the Gaussian envelope)

Demonstration

from scipy import ndimage as ndi
from skimage import data
from skimage.util import img_as_float
from skimage.filters import gabor_kernel

# Gabor Filters.
@interact
def demoGabor(frequency=(0.01,1,0.05), theta=(0,3.14,0.1), sigma=(0,5,0.1)):
    plt.gray()
    plt.imshow(np.real(gabor_kernel(frequency=frequency, theta=theta, sigma_x=sigma, sigma_y=sigma)), interpolation='nearest', extent=[-1, 1, -1, 1])
    plt.title(f'freq: {round(frequency,2)}, theta: {round(theta,2)}, sigma: {round(sigma,2)}', fontdict={'fontsize':14*fig_scale})
    plt.xticks([])
    plt.yticks([])
Figure: the Gabor kernel for the selected frequency, theta, and sigma.
if not interactive:
    plt.subplot(1, 3, 1)
    demoGabor(frequency=0.16, theta=1.2, sigma=4.0)
    plt.subplot(1, 3, 2)
    demoGabor(frequency=0.31, theta=0, sigma=3.6)
    plt.subplot(1, 3, 3)
    demoGabor(frequency=0.36, theta=1.6, sigma=1.3)
    plt.tight_layout()

Demonstration on the Fashion-MNIST data

# Calculate the magnitude of the Gabor filter response given a kernel and an input image
def magnitude(image, kernel):
    image = (image - image.mean()) / image.std() # Normalize images
    return np.sqrt(ndi.convolve(image, np.real(kernel), mode='wrap')**2 +
                   ndi.convolve(image, np.imag(kernel), mode='wrap')**2)
@interact
def demoGabor2(frequency=(0.01,1,0.05), theta=(0,3.14,0.1), sigma=(0,5,0.1)):
    plt.subplot(131)
    plt.title('Original', fontdict={'fontsize':24*fig_scale})
    plt.imshow(image)
    plt.xticks([])
    plt.yticks([])
    plt.subplot(132)
    plt.title('Gabor kernel', fontdict={'fontsize':24*fig_scale})
    plt.imshow(np.real(gabor_kernel(frequency=frequency, theta=theta, sigma_x=sigma, sigma_y=sigma)), interpolation='nearest')
    plt.xticks([])
    plt.yticks([])
    plt.subplot(133)
    plt.title('Response magnitude', fontdict={'fontsize':24*fig_scale})
    plt.imshow(np.real(magnitude(image, gabor_kernel(frequency=frequency, theta=theta, sigma_x=sigma, sigma_y=sigma))), interpolation='nearest')
    plt.tight_layout()
    plt.xticks([])
    plt.yticks([])
    plt.show()
if not interactive:
    demoGabor2(frequency=0.16, theta=1.4, sigma=1.2)

Filter banks#

  • Different filters detect different edges, shapes,…

  • Not all seem useful

# More images
# Fetch some Fashion-MNIST images
boot = X_fm[0].reshape(28, 28)
shirt = X_fm[1].reshape(28, 28)
dress = X_fm[2].reshape(28, 28)
image_names = ('boot', 'shirt', 'dress')
images = (boot, shirt, dress)

def plot_filter_bank(images):
    # Create a set of kernels, apply them to each image, store the results
    results = []
    kernel_params = []
    for theta in (0, 1):
        theta = theta / 4. * np.pi
        for frequency in (0.1, 0.2):
            for sigma in (1, 3):
                kernel = gabor_kernel(frequency, theta=theta,sigma_x=sigma,sigma_y=sigma)
                params = 'theta=%.2f,\nfrequency=%.2f\nsigma=%.2f' % (theta, frequency, sigma)
                kernel_params.append(params)
                results.append((kernel, [magnitude(img, kernel) for img in images]))

    # Plotting
    fig, axes = plt.subplots(nrows=4, ncols=9, figsize=(14*fig_scale, 8*fig_scale))
    plt.gray()
    #fig.suptitle('Image responses for Gabor filter kernels', fontsize=12)
    axes[0][0].axis('off')

    for label, img, ax in zip(image_names, images, axes[1:]):
        axs = ax[0]
        axs.imshow(img)
        axs.set_ylabel(label, fontsize=12*fig_scale)
        axs.set_xticks([]) # Remove axis ticks 
        axs.set_yticks([])
        
    # Plot Gabor kernel
    col = 1
    for label, (kernel, magnitudes), ax_col in zip(kernel_params, results, axes[0][1:]):
        ax_col.imshow(np.real(kernel), interpolation='nearest') # Plot kernel
        ax_col.set_title(label, fontsize=10*fig_scale)
        ax_col.axis('off')
        
        # Plot Gabor responses with the contrast normalized for each filter
        vmin = np.min(magnitudes)
        vmax = np.max(magnitudes)
        for patch, ax in zip(magnitudes, axes.T[col][1:]):
            ax.imshow(patch, vmin=vmin, vmax=vmax) # Plot convolutions
            ax.axis('off')
        col += 1
    
    plt.show()

plot_filter_bank(images)
Figure: Gabor filter bank; kernels (top row) and response magnitudes for the boot, shirt, and dress images.

Convolutional neural nets#

  • Finding relationships between individual pixels and the correct class is hard

  • Simplify the problem by decomposing it into smaller problems

  • First, discover ‘local’ patterns (edges, lines, endpoints)

  • Representing such local patterns as features makes it easier to learn from them

    • Deeper layers will do that for us

  • We could use convolutions, but how to choose the filters?


Convolutional Neural Networks (ConvNets)#

  • Instead of manually designing the filters, we can also learn them from data

    • Choose filter sizes (manually), initialize with small random weights

  • Forward pass: Convolutional layer slides the filter over the input, generates the output

  • Backward pass: Update the filter weights according to the loss gradients

  • Illustration for 1 filter:

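A minimal sketch (plain PyTorch, with a dummy loss of my own) showing that a filter is just a small weight tensor that receives gradients like any other:

import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)  # one 3x3 filter
x = torch.randn(1, 1, 8, 8)      # random 8x8 'image': (batch, channels, height, width)
out = conv(x)                    # forward pass: slide the filter over the input
loss = out.pow(2).mean()         # dummy loss, just for illustration
loss.backward()                  # backward pass: gradients w.r.t. the filter weights
print(conv.weight.grad.shape)    # torch.Size([1, 1, 3, 3])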

Convolutional layers: Feature maps#

  • One filter is not sufficient to detect all relevant patterns in an image

  • A convolutional layer applies and learns \(d\) filters in parallel

  • Slide \(d\) filters across the input image (in parallel) -> a (1x1xd) output per patch

  • Reassemble into a feature map with \(d\) ‘channels’, a (width x height x d) tensor.

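As a quick sketch (plain PyTorch; note that PyTorch stores feature maps channels-first, i.e. (d, height, width)):

import torch
import torch.nn as nn

layer = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)  # d = 32 filters
fmap = layer(torch.randn(1, 1, 28, 28))                           # grayscale 28x28 image
print(fmap.shape)  # torch.Size([1, 32, 26, 26]): one 26x26 map per filter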

Border effects (zero padding)#

  • Consider a 5x5 image and a 3x3 filter: there are only 9 possible locations, hence the output is a 3x3 feature map

  • If we want to maintain the image size, we use zero-padding, adding 0’s all around the input tensor.

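A sketch (standard PyTorch) of the output-size rule out = (in - kernel + 2*padding)/stride + 1:

import torch
import torch.nn as nn

x = torch.randn(1, 1, 5, 5)                                # a 5x5 input
print(nn.Conv2d(1, 1, kernel_size=3, padding=0)(x).shape)  # -> 3x3 (no padding)
print(nn.Conv2d(1, 1, kernel_size=3, padding=1)(x).shape)  # -> 5x5 (zero-padded)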

Undersampling (striding)#

  • Sometimes, we want to downsample a high-resolution image

    • Faster processing, less noisy (hence less overfitting)

    • Forces the model to summarize information in (smaller) feature maps

  • One approach is to skip values during the convolution

    • Distance between 2 windows: stride length

  • Example with stride length 2 (without padding):

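The same idea in code (standard PyTorch): a stride-2 convolution roughly halves the resolution:

import torch
import torch.nn as nn

x = torch.randn(1, 1, 5, 5)
strided = nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=0)
print(strided(x).shape)  # torch.Size([1, 1, 2, 2]): (5 - 3)/2 + 1 = 2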

Max-pooling#

  • Another approach to shrink the input tensors is max-pooling :

    • Run a filter with a fixed stride length over the image

      • Usually 2x2 filters and stride length 2

    • The filter simply returns the max (or avg ) of all values

  • Aggressively reduces the number of weights (less overfitting)

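A sketch (standard PyTorch) on a tiny input, showing that each output value is the max of one non-overlapping 2x2 window:

import torch
import torch.nn as nn

x = torch.tensor([[[[1., 2., 5., 6.],
                    [3., 4., 7., 8.],
                    [1., 0., 2., 1.],
                    [0., 1., 3., 4.]]]])  # shape (1, 1, 4, 4)
print(nn.MaxPool2d(kernel_size=2, stride=2)(x))
# tensor([[[[4., 8.],
#           [1., 4.]]]])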

Receptive field#

  • Receptive field: how much each output neuron ‘sees’ of the input image

  • Translation invariance: shifting the input does not affect the output

    • Large receptive field -> neurons can ‘see’ patterns anywhere in the input

  • \(n \times n\) convolutions only increase the receptive field by \(n-1\) each layer

  • Maxpooling doubles the receptive field without deepening the network
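This growth is easy to compute. A small helper of my own, using the standard receptive-field arithmetic (each layer adds \(k-1\) times the product of all earlier strides):

def receptive_field(layers):
    """layers: list of (kernel_size, stride) tuples, applied in order."""
    rf, jump = 1, 1
    for k, s in layers:
        rf += (k - 1) * jump  # growth scales with the product of earlier strides
        jump *= s
    return rf

print(receptive_field([(3, 1), (3, 1)]))          # 5: two stacked 3x3 convolutions
print(receptive_field([(3, 1), (2, 2), (3, 1)]))  # 8: add a 2x2 maxpool in between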

import matplotlib.patches as patches

def draw_grid(ax, size, offset):
    """Draws a grid without text labels"""
    for i in range(size):
        for j in range(size):
            ax.add_patch(patches.Rectangle((j + offset[0], -i + offset[1]), 1, 1, 
                                           fill=False, edgecolor='gray', linewidth=1))

def highlight_region(ax, positions, offset, color, alpha=0.3):
    """Highlights a specific region in the grid"""
    for x, y in positions:
        ax.add_patch(patches.Rectangle((x + offset[0], -y + offset[1]), 1, 1, fill=True, color=color, alpha=alpha))

def draw_connection_hull(ax, points, color, alpha):
    """Draws a polygon representing the hull of connection lines"""
    ax.add_patch(patches.Polygon(points, closed=True, facecolor=color, alpha=alpha, edgecolor=None))
    
def add_titles(ax, option):
    """Adds titles above each matrix"""
    titles = ["Input", option, "Output_1", "Kernel_2", "Output_2"]
    positions = [(0, 1.5), (9, 1.5), (15, 1.5), (20, 1.5), (24, 1.5)]
    
    for title, (x, y) in zip(titles, positions):
        ax.text(x, y, title, fontsize=12, fontweight='bold', ha='left')

layer_options = ['3x3 Kernel', '3x3 Kernel, Stride 2', '5x5 Kernel', 'MaxPool 2x2']
layer_options2 = ['3x3 Kernel', '3x3 Kernel, Dilation 2']

@interact
def visualize_receptive_field(option=layer_options):
    fig, ax = plt.subplots(figsize=(18, 6))
    ax.set_xlim(-2, 26)
    ax.set_ylim(-9, 2)
    ax.axis('off')
    add_titles(ax, option)
    kernel_size = 0
    
    grids = [(8, (0, 0)), (4, (15, 0)), (3, (20, 0)), (2, (24, 0))]
    
    single_output_rf = [(0, 0)]
    for size, offset in grids:
        draw_grid(ax, size, offset)
    
    if option == 'MaxPool 2x2':
        full_input_rf = [(x, y) for x in range(6) for y in range(6)]
        highlight_region(ax, full_input_rf, (0, 0), 'green', alpha=0.3)
    else:
        kernel_size = 3 if option.startswith('3x3 Kernel') else 5
        draw_grid(ax, kernel_size, (9, 0))
        
        input_highlight_size = kernel_size + 2
        if option == '3x3 Kernel, Stride 2' or option == '3x3 Kernel, Dilation 2':
            input_highlight_size = kernel_size + 4

        full_input_rf = [(x, y) for x in range(input_highlight_size) for y in range(input_highlight_size)]
        kernel_1 = [(x, y) for x in range(kernel_size) for y in range(kernel_size)]
        kernel_rf = kernel_1
        if option == '3x3 Kernel, Dilation 2':
            kernel_rf = [(x*2, y*2) for x in range(kernel_size) for y in range(kernel_size)]

        highlight_region(ax, full_input_rf, (0, 0), 'green')
        highlight_region(ax, kernel_rf, (0, 0), 'blue')
        highlight_region(ax, kernel_1, (9, 0), 'blue')
        highlight_region(ax, single_output_rf, (15, 0), 'blue')
    
    kernel2_rf = [(x, y) for x in range(3) for y in range(3)]
    
    highlight_region(ax, kernel2_rf, (15, 0), 'green')
    highlight_region(ax, kernel2_rf, (20, 0), 'green')
    highlight_region(ax, single_output_rf, (24, 0), 'green')
    
    connection_hulls = [
        ([(23, -2), (23, 1), (24, 1), (24, 0)], 'green', 0.1),
        ([(18, -2), (18, 1), (20, 1), (20, -2)], 'green', 0.1)
    ]
    
    kernel_fp = kernel_size * 2 - 1 if option == '3x3 Kernel, Dilation 2' else kernel_size

    if option != 'MaxPool 2x2':
        connection_hulls.extend([
            ([(kernel_fp, 1-kernel_fp), (kernel_fp, 1), (9, 1), (9, 1-kernel_size)], 'blue', 0.1),
            ([(9+kernel_size, 1-kernel_size), (9+kernel_size, 1), (15, 1), (15, 0)], 'blue', 0.1)
        ])
    else:
        connection_hulls.extend([
            ([(6, -5), (6, 1), (15, 1), (15, -2)], 'green', 0.1),
        ])        
    
    for points, color, alpha in connection_hulls:
        draw_connection_hull(ax, points, color, alpha)
    
    plt.show()
if not interactive:
    for option in layer_options[0::3]:
        visualize_receptive_field(option=option)

Dilated convolutions#

  • Expand the filter by introducing 'gaps' between its elements, spacing them out

  • Increases the receptive field exponentially

  • Doesn’t need extra parameters or computation (unlike larger filters)

  • Retains feature map size (unlike pooling)
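A sketch (standard PyTorch): a 3x3 kernel with dilation 2 covers a 5x5 area, so the output shrinks as if a 5x5 kernel were used, while still having only 9 weights:

import torch
import torch.nn as nn

x = torch.randn(1, 1, 9, 9)
dilated = nn.Conv2d(1, 1, kernel_size=3, dilation=2)  # 3x3 taps spread over a 5x5 area
print(dilated(x).shape)  # torch.Size([1, 1, 5, 5]): effective kernel size is 5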

@interact
def visualize_receptive_field2(option=layer_options2):
    visualize_receptive_field(option)
if not interactive:
    visualize_receptive_field(option=layer_options2[1])

Convolutional nets in practice#

  • Use multiple convolutional layers to learn patterns at different levels of abstraction

    • Find local patterns first (e.g. edges), then patterns across those patterns

  • Use MaxPooling layers to reduce resolution, increase translation invariance

  • Use sufficient filters in the first layer (otherwise information gets lost)

  • In deeper layers, use increasingly more filters

    • Preserve information about the input as resolution decreases

    • Avoid decreasing the number of activations (resolution x nr of filters)

  • For very deep nets, add skip connections to preserve information (and gradients)

    • Sums up outputs of earlier layers to those of later layers (with same dimensions)
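A minimal residual-block sketch of such a skip connection (a common pattern, not code from this lecture; padding keeps the dimensions equal so the sum is valid):

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Two 3x3 convolutions whose output is summed with the block's input."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        return self.relu(out + x)  # skip connection: add input to output

print(ResidualBlock(64)(torch.randn(1, 64, 16, 16)).shape)  # torch.Size([1, 64, 16, 16])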

Example with PyTorch#

  • Conv2d for 2D convolutional layers

    • Grayscale image: 1 in_channels

    • 32 filters: 32 out_channels, 3x3 size

    • Deeper layers use 64 filters

    • ReLU activation, no padding

    • MaxPool2d for max-pooling, 2x2

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=0),
    nn.ReLU(), 
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=0),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=0),
    nn.ReLU()
)
  • Observe how the 1x28x28 input image is transformed into a 64x3x3 feature map

    • In PyTorch, shapes are (batch_size, channels, height, width)

  • Conv2d parameters = (kernel size^2 × input channels + 1) × output channels (checked right after this list)

  • No zero-padding: every output is 2 pixels less in every dimension

  • After every MaxPooling, resolution halved in every dimension
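A quick hand check of this formula for the three convolutional layers above (the numbers should match the summary below):

# Conv2d parameters = (kernel_size^2 * in_channels + 1) * out_channels
print((3*3*1 + 1) * 32)   # 320:   first Conv2d layer
print((3*3*32 + 1) * 64)  # 18496: second Conv2d layer
print((3*3*64 + 1) * 64)  # 36928: third Conv2d layer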

from torchinfo import summary
summary(model, input_size=(1, 1, 28, 28))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Sequential                               [1, 64, 3, 3]             --
├─Conv2d: 1-1                            [1, 32, 26, 26]           320
├─ReLU: 1-2                              [1, 32, 26, 26]           --
├─MaxPool2d: 1-3                         [1, 32, 13, 13]           --
├─Conv2d: 1-4                            [1, 64, 11, 11]           18,496
├─ReLU: 1-5                              [1, 64, 11, 11]           --
├─MaxPool2d: 1-6                         [1, 64, 5, 5]             --
├─Conv2d: 1-7                            [1, 64, 3, 3]             36,928
├─ReLU: 1-8                              [1, 64, 3, 3]             --
==========================================================================================
Total params: 55,744
Trainable params: 55,744
Non-trainable params: 0
Total mult-adds (M): 2.79
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.24
Params size (MB): 0.22
Estimated Total Size (MB): 0.47
==========================================================================================
  • To classify the images, we still need a hidden linear layer and an output layer.

  • We flatten the 64x3x3 feature map into a vector of size 576

model = nn.Sequential(
    ...
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=0),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(64 * 3 * 3, 64),
    nn.ReLU(),
    nn.Linear(64, 10)
)
model = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=0),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=0),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=0),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(64 * 3 * 3, 64),
    nn.ReLU(),
    nn.Linear(64, 10)
)

Complete model. Flattening adds a lot of weights!

summary(model, input_size=(1, 1, 28, 28))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Sequential                               [1, 10]                   --
├─Conv2d: 1-1                            [1, 32, 26, 26]           320
├─ReLU: 1-2                              [1, 32, 26, 26]           --
├─MaxPool2d: 1-3                         [1, 32, 13, 13]           --
├─Conv2d: 1-4                            [1, 64, 11, 11]           18,496
├─ReLU: 1-5                              [1, 64, 11, 11]           --
├─MaxPool2d: 1-6                         [1, 64, 5, 5]             --
├─Conv2d: 1-7                            [1, 64, 3, 3]             36,928
├─ReLU: 1-8                              [1, 64, 3, 3]             --
├─Flatten: 1-9                           [1, 576]                  --
├─Linear: 1-10                           [1, 64]                   36,928
├─ReLU: 1-11                             [1, 64]                   --
├─Linear: 1-12                           [1, 10]                   650
==========================================================================================
Total params: 93,322
Trainable params: 93,322
Non-trainable params: 0
Total mult-adds (M): 2.82
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.24
Params size (MB): 0.37
Estimated Total Size (MB): 0.62
==========================================================================================

Global Average Pooling (GAP)#

  • Instead of flattening, we apply GAP: it returns the average of each activation map

  • We can drop the hidden dense layer, as long as the number of feature maps exceeds the number of classes

model = nn.Sequential(...
    nn.AdaptiveAvgPool2d(1), # Global Average Pooling
    nn.Flatten(),            # Convert (batch, 64, 1, 1) -> (batch, 64)
    nn.Linear(64, 10))       # Output layer for 10 classes
  • With GlobalAveragePooling: far fewer weights to learn

  • Use with caution: this destroys the location information learned by the CNN

  • Not ideal for tasks such as object localization

model = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=0),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=0),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=0),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),  # Global Average Pooling (GAP)
    nn.Flatten(),  # Convert (batch, 64, 1, 1) -> (batch, 64)
    nn.Linear(64, 10)  # Output layer for 10 classes
)
summary(model, input_size=(1, 1, 28, 28))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Sequential                               [1, 10]                   --
├─Conv2d: 1-1                            [1, 32, 26, 26]           320
├─ReLU: 1-2                              [1, 32, 26, 26]           --
├─MaxPool2d: 1-3                         [1, 32, 13, 13]           --
├─Conv2d: 1-4                            [1, 64, 11, 11]           18,496
├─ReLU: 1-5                              [1, 64, 11, 11]           --
├─MaxPool2d: 1-6                         [1, 64, 5, 5]             --
├─Conv2d: 1-7                            [1, 64, 3, 3]             36,928
├─ReLU: 1-8                              [1, 64, 3, 3]             --
├─AdaptiveAvgPool2d: 1-9                 [1, 64, 1, 1]             --
├─Flatten: 1-10                          [1, 64]                   --
├─Linear: 1-11                           [1, 10]                   650
==========================================================================================
Total params: 56,394
Trainable params: 56,394
Non-trainable params: 0
Total mult-adds (M): 2.79
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.24
Params size (MB): 0.23
Estimated Total Size (MB): 0.47
==========================================================================================

Run the model on the MNIST dataset

  • Train and test as usual: 99% accuracy

    • Compared to 97.8% accuracy with the dense architecture

    • Flatten and GlobalAveragePooling yield similar performance

import pytorch_lightning as pl

# Keeps a history of scores to make plotting easier
class MetricTracker(pl.Callback):
    def __init__(self):
        super().__init__()
        self.history = {
            "train_loss": [],
            "train_acc": [],
            "val_loss": [],
            "val_acc": []
        }
        self.first_validation = True  # Flag to ignore first validation step

    def on_train_epoch_end(self, trainer, pl_module):
        """Collects training metrics at the end of each epoch"""
        train_loss = trainer.callback_metrics.get("train_loss")
        train_acc = trainer.callback_metrics.get("train_acc")

        if train_loss is not None:
            self.history["train_loss"].append(train_loss.cpu().item())
        if train_acc is not None:
            self.history["train_acc"].append(train_acc.cpu().item())

    def on_validation_epoch_end(self, trainer, pl_module):
        """Collects validation metrics at the end of each epoch"""
        if self.first_validation:  
            self.first_validation = False  # Skip first validation logging
            return  

        val_loss = trainer.callback_metrics.get("val_loss")
        val_acc = trainer.callback_metrics.get("val_acc")

        if val_loss is not None:
            self.history["val_loss"].append(val_loss.cpu().item())
        if val_acc is not None:
            self.history["val_acc"].append(val_acc.cpu().item())
            
def plot_training(history):
    plt.figure(figsize=(12, 4))  # Increased figure size

    # Plot Loss
    plt.subplot(1, 2, 1)
    plt.plot(history["train_loss"], label="Train Loss", marker='o', lw=2)
    plt.plot(history["val_loss"], label="Validation Loss", marker='o', lw=2)
    plt.xlabel("Epochs", fontsize=14)  # Larger font size
    plt.ylabel("Loss", fontsize=14)
    plt.title("Loss vs. Epochs", fontsize=16, fontweight="bold")
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.legend(fontsize=12)

    # Plot Accuracy
    plt.subplot(1, 2, 2)
    plt.plot(history["train_acc"], label="Train Accuracy", marker='o', lw=2)
    plt.plot(history["val_acc"], label="Validation Accuracy", marker='o', lw=2)
    plt.xlabel("Epochs", fontsize=14)
    plt.ylabel("Accuracy", fontsize=14)
    plt.title("Accuracy vs. Epochs", fontsize=16, fontweight="bold")
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.legend(fontsize=12)

    plt.tight_layout()  # Adjust layout for readability
    plt.show()
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy

# Model in Pytorch Lightning
class MNISTModel(pl.LightningModule):
    def __init__(self, learning_rate=0.001):
        super().__init__()
        self.save_hyperparameters()
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 64, kernel_size=3, padding=0),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, 10)
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    # Logging of loss and accuracy for later plotting
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = accuracy(logits, y, task="multiclass", num_classes=10)
        self.log("train_loss", loss, prog_bar=True, on_epoch=True)
        self.log("train_acc", acc, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = accuracy(logits, y, task="multiclass", num_classes=10)
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        self.log("val_acc", acc, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

# Compute mean and std to normalize the data
# Couldn't find a way to do this automatically in PyTorch :(
# Normalization is not strictly needed, but speeds up convergence
dataset = datasets.MNIST(root=".", train=True, transform=transforms.ToTensor(), download=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=1000, num_workers=4, shuffle=False)
mean = torch.mean(torch.stack([batch[0].mean() for batch in loader]))
std = torch.mean(torch.stack([batch[0].std() for batch in loader]))  # Approximation: averages per-batch stds

# Loading the data. We'll discuss data loaders again soon.
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((mean,), (std,))  # Normalize MNIST. Make more general?
        ])

    def prepare_data(self):
        datasets.MNIST(root=".", train=True, download=True)  # Downloads dataset

    def setup(self, stage=None):
        full_train = datasets.MNIST(root=".", train=True, transform=self.transform)
        self.train, self.val = random_split(full_train, [55000, 5000])
        self.test = datasets.MNIST(root=".", train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size, num_workers=4)

    
# Initialize data & model
pl.seed_everything(42)  # Ensure reproducibility
data_module = MNISTDataModule(batch_size=64)
model = MNISTModel(learning_rate=0.001)

# Trainer with logging & checkpointing
accelerator = "cpu"
if torch.backends.mps.is_available():
    accelerator = "mps"
if torch.cuda.is_available():
    accelerator = "gpu"

metric_tracker = MetricTracker()  # Callback to track per-epoch metrics

trainer = pl.Trainer(
    max_epochs=10,  # Train for 10 epochs
    accelerator=accelerator,
    devices="auto",
    log_every_n_steps=10,
    deterministic=True,
    callbacks=[metric_tracker]  # Attach callback to trainer
)

if histories and histories["mnist"]:
    history = histories["mnist"]
else:
    trainer.fit(model, datamodule=data_module)
    history = metric_tracker.history

# Test after training (sanity check)
# trainer.test(model, datamodule=data_module)
Seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
plot_training(history)
Figure: training and validation loss and accuracy curves on MNIST.

Cats vs Dogs#

  • A more realistic dataset: Cats vs Dogs

    • Colored JPEG images, different sizes

    • Not nicely centered, translation invariance is important

  • Preprocessing

    • Decode JPEG images to floating-point tensors

    • Rescale pixel values to [0,1]

    • Resize images to 150x150 pixels

Uncomment to run from scratch

# TODO: upload dataset to OpenML so we can avoid the manual steps.

import os, shutil 
# Download data from https://www.kaggle.com/c/dogs-vs-cats/data
# Uncompress `train.zip` into the `original_dataset_dir`
original_dataset_dir = '../data/cats-vs-dogs_original'

# The directory where we will
# store our smaller dataset
train_dir = os.path.join(data_dir, 'train')
validation_dir = os.path.join(data_dir, 'validation')

if not os.path.exists(data_dir):
    os.mkdir(data_dir)
    os.mkdir(train_dir)
    os.mkdir(validation_dir)
    
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')

if not os.path.exists(train_cats_dir):
    os.mkdir(train_cats_dir)
    os.mkdir(train_dogs_dir)
    os.mkdir(validation_cats_dir)
    os.mkdir(validation_dogs_dir)

# Copy first 2000 cat images to train_cats_dir
fnames = ['cat.{}.jpg'.format(i) for i in range(2000)]
for fname in fnames:
    src = os.path.join(original_dataset_dir, fname)
    dst = os.path.join(train_cats_dir, fname)
    shutil.copyfile(src, dst)
    
# Copy next 1000 cat images to validation_cats_dir
fnames = ['cat.{}.jpg'.format(i) for i in range(2000, 3000)]
for fname in fnames:
    src = os.path.join(original_dataset_dir, fname)
    dst = os.path.join(validation_cats_dir, fname)
    shutil.copyfile(src, dst)
    
# Copy first 2000 dog images to train_dogs_dir
fnames = ['dog.{}.jpg'.format(i) for i in range(2000)]
for fname in fnames:
    src = os.path.join(original_dataset_dir, fname)
    dst = os.path.join(train_dogs_dir, fname)
    shutil.copyfile(src, dst)
    
# Copy next 1000 dog images to validation_dogs_dir
fnames = ['dog.{}.jpg'.format(i) for i in range(2000, 3000)]
for fname in fnames:
    src = os.path.join(original_dataset_dir, fname)
    dst = os.path.join(validation_dogs_dir, fname)
    shutil.copyfile(src, dst)
import random

# Set random seed for reproducibility
def seed_everything(seed=42):
    pl.seed_everything(seed)  # Sets seed for PyTorch Lightning
    torch.manual_seed(seed)  # PyTorch
    torch.cuda.manual_seed_all(seed)  # CUDA (if available)
    np.random.seed(seed)  # NumPy
    random.seed(seed)  # Python random module
    torch.backends.cudnn.deterministic = True  # Ensures reproducibility in CNNs
    torch.backends.cudnn.benchmark = False  # Ensures consistency

seed_everything(42)  # Set global seed

class CatDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size=20, img_size=(150, 150)):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.img_size = img_size

        # Define image transformations
        self.transform = transforms.Compose([
            transforms.Resize(self.img_size),  # Resize to 150x150
            transforms.ToTensor(),  # Convert to tensor (also scales 0-1)
        ])

    def setup(self, stage=None):
        """Load datasets"""
        train_dir = os.path.join(self.data_dir, "train")
        val_dir = os.path.join(self.data_dir, "validation")

        self.train_dataset = datasets.ImageFolder(root=train_dir, transform=self.transform)
        self.val_dataset = datasets.ImageFolder(root=val_dir, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4)

# ----------------------------
# Load dataset and visualize a batch
# ----------------------------
data_module = CatDataModule(data_dir=data_dir)
data_module.setup()
train_loader = data_module.train_dataloader()
Seed set to 42
# Get a batch of data
data_batch, labels_batch = next(iter(train_loader))

# Visualize images
plt.figure(figsize=(10, 5))
for i in range(7):
    plt.subplot(1, 7, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(data_batch[i].permute(1, 2, 0))  # Convert from (C, H, W) to (H, W, C)
    plt.title("Cat" if labels_batch[i] == 0 else "Dog", fontsize=16)
plt.tight_layout()
plt.show()
Figure: a sample batch of resized cat and dog training images.

Data loader#

  • We create a Pytorch Lightning DataModule to do preprocessing and data loading

class ImageDataModule(pl.LightningDataModule):
  def __init__(self, data_dir, batch_size=20, img_size=(150, 150)):
    super().__init__()
    self.data_dir, self.batch_size, self.img_size = data_dir, batch_size, img_size
    self.transform = transforms.Compose([
      transforms.Resize(self.img_size),  # Resize to 150x150
      transforms.ToTensor()])  # Convert to tensor (also scales to [0,1])
  def setup(self, stage=None):
    train_dir = os.path.join(self.data_dir, "train")
    val_dir = os.path.join(self.data_dir, "validation")
    self.train_dataset = datasets.ImageFolder(root=train_dir, transform=self.transform)
    self.val_dataset = datasets.ImageFolder(root=val_dir, transform=self.transform)
  def train_dataloader(self):
    return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
  def val_dataloader(self):
    return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)
from torchmetrics.classification import Accuracy

# Model in PyTorch Lightning
class CatImageClassifier(pl.LightningModule):
    def __init__(self, learning_rate=0.001):
        super().__init__()
        self.save_hyperparameters()

        # Define convolutional layers
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(32, 64, kernel_size=3, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 128, kernel_size=3, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.AdaptiveAvgPool2d(1)  # GAP replaces Flatten()
        )

        # Fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Linear(128, 512),  # GAP outputs (batch, 128, 1, 1) → Flatten to (batch, 128)
            nn.ReLU(),
            nn.Linear(512, 1)  # Binary classification (1 output neuron)
        )

        self.loss_fn = nn.BCEWithLogitsLoss()
        self.accuracy = Accuracy(task="binary")

    def forward(self, x):
        x = self.conv_layers(x)  # Convolutions + GAP
        x = x.view(x.size(0), -1)  # Flatten from (batch, 128, 1, 1) → (batch, 128)
        x = self.fc_layers(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x).squeeze(1)  # Remove extra dimension
        loss = self.loss_fn(logits, y.float())  # BCE loss requires float labels

        preds = torch.sigmoid(logits)  # Convert logits to probabilities
        acc = self.accuracy(preds, y)

        self.log("train_loss", loss, prog_bar=True, on_epoch=True)
        self.log("train_acc", acc, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x).squeeze(1)
        loss = self.loss_fn(logits, y.float())

        preds = torch.sigmoid(logits)
        acc = self.accuracy(preds, y)

        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        self.log("val_acc", acc, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

Model#

Since the images are more complex, we add another convolutional layer and increase the number of filters to 128.

model = CatImageClassifier()
summary(model, input_size=(1, 3, 150, 150))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
CatImageClassifier                       [1, 1]                    --
├─Sequential: 1-1                        [1, 128, 1, 1]            --
│    └─Conv2d: 2-1                       [1, 32, 148, 148]         896
│    └─ReLU: 2-2                         [1, 32, 148, 148]         --
│    └─MaxPool2d: 2-3                    [1, 32, 74, 74]           --
│    └─Conv2d: 2-4                       [1, 64, 72, 72]           18,496
│    └─ReLU: 2-5                         [1, 64, 72, 72]           --
│    └─MaxPool2d: 2-6                    [1, 64, 36, 36]           --
│    └─Conv2d: 2-7                       [1, 128, 34, 34]          73,856
│    └─ReLU: 2-8                         [1, 128, 34, 34]          --
│    └─MaxPool2d: 2-9                    [1, 128, 17, 17]          --
│    └─Conv2d: 2-10                      [1, 128, 15, 15]          147,584
│    └─ReLU: 2-11                        [1, 128, 15, 15]          --
│    └─MaxPool2d: 2-12                   [1, 128, 7, 7]            --
│    └─AdaptiveAvgPool2d: 2-13           [1, 128, 1, 1]            --
├─Sequential: 1-2                        [1, 1]                    --
│    └─Linear: 2-14                      [1, 512]                  66,048
│    └─ReLU: 2-15                        [1, 512]                  --
│    └─Linear: 2-16                      [1, 1]                    513
==========================================================================================
Total params: 307,393
Trainable params: 307,393
Non-trainable params: 0
Total mult-adds (M): 234.16
==========================================================================================
Input size (MB): 0.27
Forward/backward pass size (MB): 9.68
Params size (MB): 1.23
Estimated Total Size (MB): 11.18
==========================================================================================

Training#

  • We use a Trainer module (from PyTorch Lightning) to simplify training

trainer = pl.Trainer(
    max_epochs=20,        # Train for 20 epochs
    accelerator="gpu",    # Move data and model to GPU
    devices="auto",       # Number of GPUs
    deterministic=True,   # Set random seeds, for reproducibility
    callbacks=[metric_tracker,      # Callback for logging loss and acc
               checkpoint_callback] # Callback for logging weights
)
trainer.fit(model, datamodule=data_module)
  • Tip: to store the best model weights, you can add a ModelCheckpoint callback

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",   # Save model with lowest val. loss
    mode="min",           # "min" for loss, "max" for accuracy
    save_top_k=1,         # Keep only the best model
    dirpath="weights/",   # Directory to save checkpoints
    filename="cat_model", # File name pattern
)

The model learns well for the first 20 epochs, but then starts overfitting a lot!

from pytorch_lightning.callbacks import ModelCheckpoint

# Train Cat model
pl.seed_everything(42)  # Ensure reproducibility
data_module = CatDataModule(data_dir, batch_size=64)
model = CatImageClassifier(learning_rate=0.001)
metric_tracker = MetricTracker()  # Callback to track per-epoch metrics

# Define checkpoint callback to save the best model
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",  # Saves model with lowest validation loss
    mode="min",  # "min" for loss, "max" for accuracy
    save_top_k=1,  # Keep only the best model
    dirpath="../data/checkpoints/",  # Directory to save checkpoints
    filename="cat_model",  # File name pattern
)

trainer = pl.Trainer(
    max_epochs=50,  # Train for 50 epochs
    accelerator=accelerator,
    devices="auto",
    log_every_n_steps=10,
    deterministic=True,
    callbacks=[metric_tracker, checkpoint_callback]  # Attach callback to trainer
)

if histories and histories["cat"]:
    history_cat = histories["cat"]
else:
    trainer.fit(model, datamodule=data_module)
    history_cat = metric_tracker.history
Seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
plot_training(history_cat)
Figure: training and validation curves on cats vs. dogs, showing overfitting after roughly 20 epochs.

Solving overfitting in CNNs#

  • There are various ways to further improve the model:

    • Generating more training data (data augmentation)

    • Regularization (e.g. Dropout, L1/L2, Batch Normalization,…)

    • Use pretrained rather than randomly initialized filters

      • These are trained on a lot more data

Data augmentation#

  • Generate new images via image transformations (only on training data!)

    • Images will be randomly transformed every epoch

  • Update the transform in the data module

self.train_transform = transforms.Compose([
    transforms.Resize(self.img_size), # Resize to 150x150
    transforms.RandomRotation(40),    # Rotations up to 40 degrees
    transforms.RandomResizedCrop(self.img_size, 
                                 scale=(0.8, 1.2)), # Scale + crop, up to 20%
    transforms.RandomHorizontalFlip(),              # Horizontal flip
    transforms.RandomAffine(degrees=0, shear=20),   # Shear, up to 20 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, 
                           saturation=0.2),         # Color jitter
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
class CatDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size=20, img_size=(150, 150)):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.img_size = img_size

        # Training Data Augmentation 
        self.train_transform = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.RandomRotation(40),
            transforms.RandomResizedCrop(self.img_size, scale=(0.8, 1.2)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomAffine(degrees=0, shear=20),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        # Test Data Transforms (NO augmentation, just resize + normalize)
        self.val_transform = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def setup(self, stage=None):
        """Load datasets with correct transforms"""
        train_dir = os.path.join(self.data_dir, "train")
        val_dir = os.path.join(self.data_dir, "validation")

        # Apply augmentation only to training data
        self.train_dataset = datasets.ImageFolder(root=train_dir, transform=self.train_transform)
        self.val_dataset = datasets.ImageFolder(root=val_dir, transform=self.val_transform)

    def train_dataloader(self):
        """Applies augmentation via the pre-defined transform"""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        """Loads validation data WITHOUT augmentation"""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4)

Augmentation example

def show_augmented_images(data_module, num_images=8):
    """Visualize the same image with different random augmentations."""
    
    train_dataset = data_module.train_dataset  # Get training dataset with augmentation
    
    # Pick a random training image; each fetch below applies a fresh random augmentation
    idx = np.random.randint(len(train_dataset))

    fig, axes = plt.subplots(2, 4, figsize=(10, 5))  # Create a 2x4 grid
    axes = axes.flatten()

    for i in range(num_images):
        # Apply new augmentation on the same image each time
        img, _ = train_dataset[idx]  # Re-fetch the same image, but with a new random augmentation
        
        # Convert tensor image back to NumPy format
        img = img.permute(1, 2, 0).numpy()  # Convert (C, H, W) → (H, W, C)
        img = (img - img.min()) / (img.max() - img.min())  # Rescale to [0,1] for display

        # Plot the augmented image
        axes[i].imshow(img)
        axes[i].set_xticks([])
        axes[i].set_yticks([])

    plt.tight_layout()
    plt.show()

# Load dataset and visualize augmented images
data_module = CatDataModule(data_dir)  # Set correct dataset path
data_module.setup()
show_augmented_images(data_module)
../_images/e90fb8a88b6721e45a5c9b3d4f1bc8c3feed009e0f8e193a923162c9e8d8216a.png

We also add Dropout before the Dense layer, and L2 regularization (‘weight decay’) in Adam

class CatImageClassifier(pl.LightningModule):
    def __init__(self, learning_rate=0.001):
        super().__init__()
        self.save_hyperparameters()

        # Define convolutional layers (CNN)
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(32, 64, kernel_size=3, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 128, kernel_size=3, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.AdaptiveAvgPool2d(1)  # GAP instead of Flatten
        )

        # Fully connected layers (FC) with Dropout
        self.fc_layers = nn.Sequential(
            nn.Linear(128, 512),  # GAP outputs (batch, 128, 1, 1) → Flatten to (batch, 128)
            nn.ReLU(),
            nn.Dropout(0.5),  # Dropout (same as Keras Dropout(0.5))
            nn.Linear(512, 1)  # Binary classification (1 output neuron)
        )

        self.loss_fn = nn.BCEWithLogitsLoss()
        self.accuracy = Accuracy(task="binary")

    def forward(self, x):
        x = self.conv_layers(x)  # Convolutions + GAP
        x = x.view(x.size(0), -1)  # Flatten from (batch, 128, 1, 1) → (batch, 128)
        x = self.fc_layers(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x).squeeze(1)  # Remove extra dimension
        loss = self.loss_fn(logits, y.float())  # BCE loss requires float labels

        preds = torch.sigmoid(logits)  # Convert logits to probabilities
        acc = self.accuracy(preds, y)

        self.log("train_loss", loss, prog_bar=True, on_epoch=True)
        self.log("train_acc", acc, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x).squeeze(1)
        loss = self.loss_fn(logits, y.float())

        preds = torch.sigmoid(logits)
        acc = self.accuracy(preds, y)

        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        self.log("val_acc", acc, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=1e-4)
    
model = CatImageClassifier()
summary(model, input_size=(1, 3, 150, 150))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
CatImageClassifier                       [1, 1]                    --
├─Sequential: 1-1                        [1, 128, 1, 1]            --
│    └─Conv2d: 2-1                       [1, 32, 148, 148]         896
│    └─ReLU: 2-2                         [1, 32, 148, 148]         --
│    └─MaxPool2d: 2-3                    [1, 32, 74, 74]           --
│    └─Conv2d: 2-4                       [1, 64, 72, 72]           18,496
│    └─ReLU: 2-5                         [1, 64, 72, 72]           --
│    └─MaxPool2d: 2-6                    [1, 64, 36, 36]           --
│    └─Conv2d: 2-7                       [1, 128, 34, 34]          73,856
│    └─ReLU: 2-8                         [1, 128, 34, 34]          --
│    └─MaxPool2d: 2-9                    [1, 128, 17, 17]          --
│    └─Conv2d: 2-10                      [1, 128, 15, 15]          147,584
│    └─ReLU: 2-11                        [1, 128, 15, 15]          --
│    └─MaxPool2d: 2-12                   [1, 128, 7, 7]            --
│    └─AdaptiveAvgPool2d: 2-13           [1, 128, 1, 1]            --
├─Sequential: 1-2                        [1, 1]                    --
│    └─Linear: 2-14                      [1, 512]                  66,048
│    └─ReLU: 2-15                        [1, 512]                  --
│    └─Dropout: 2-16                     [1, 512]                  --
│    └─Linear: 2-17                      [1, 1]                    513
==========================================================================================
Total params: 307,393
Trainable params: 307,393
Non-trainable params: 0
Total mult-adds (M): 234.16
==========================================================================================
Input size (MB): 0.27
Forward/backward pass size (MB): 9.68
Params size (MB): 1.23
Estimated Total Size (MB): 11.18
==========================================================================================

No more overfitting!

Hide code cell source
pl.seed_everything(42)  # Ensure reproducibility
data_module = CatDataModule(data_dir, batch_size=64)
model = CatImageClassifier(learning_rate=0.001)
metric_tracker = MetricTracker()  # Callback to track per-epoch metrics

trainer = pl.Trainer(
    max_epochs=50,  # Train for up to 50 epochs
    accelerator=accelerator,
    devices="auto",
    log_every_n_steps=10,
    deterministic=True,
    callbacks=[metric_tracker, checkpoint_callback]  # Attach callback to trainer
)

# If previously trained, load history and weights
if histories and histories["cat2"]:
    history_cat2 = histories["cat2"]
    model = CatImageClassifier.load_from_checkpoint("../data/checkpoints/cat_model.ckpt")
else:
    trainer.fit(model, datamodule=data_module)
    history_cat2 = metric_tracker.history
    
# Set to evaluation mode so we don't update the weights
model.eval()
Seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
plot_training(history_cat2)
../_images/f387b643aa497b74d53a03093022b69af7edff151b5ea1808e58c1123f5576be.png
import pickle 

histories = {"mnist":history,"cat":history_cat,"cat2":history_cat2}
with open("../data/histories.pkl", "wb") as f:
    pickle.dump(histories, f)

Real-world CNNs#

VGG16#

  • Deeper architecture (16 layers): allows it to learn more complex high-level features

    • Textures, patterns, shapes,…

  • Small filters (3x3) work better: capture spatial information while reducing number of parameters

  • Max-pooling (2x2): reduces spatial dimension, improves translation invariance

    • Lower resolution forces model to learn robust features (less sensitive to small input changes)

    • Pooling only after every 2-3 convolutional layers, otherwise the spatial dimensions shrink too fast

  • Downside: too many parameters, expensive to train

ml
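
To make this concrete, here is a minimal sketch (our own illustration, not the lecture's code) of one VGG-style block in PyTorch: two 3x3 convolutions that keep the resolution, followed by 2x2 max-pooling that halves it. VGG16 stacks such blocks while doubling the channel count.

import torch.nn as nn

def vgg_block(in_channels, out_channels):
    """One VGG-style block: two 3x3 convs (padding=1 keeps the resolution),
    then 2x2 max-pooling halves the spatial dimensions."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
    )

# Stacking blocks while doubling the channels, as VGG16 does
stem = nn.Sequential(vgg_block(3, 64), vgg_block(64, 128), vgg_block(128, 256))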

Inceptionv3#

  • Inception modules: parallel branches learn features of different sizes and scales (3x3, 5x5, 7x7,…)

    • Add reduction blocks that reduce dimensionality via convolutions with stride 2

  • Factorized convolutions: a 3x3 conv. can be replaced by combining 1x3 and 3x1, and is 33% cheaper

    • A 5x5 can be replaced by combining 3x3 and 3x3, which can in turn be factorized as above

  • 1x1 convolutions, or Network-In-Network (NIN) layers help reduce the number of channels: cheaper

  • An auxiliary classifier adds an additional gradient signal deeper in the network

ml
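
A minimal sketch of such a module (channel counts are illustrative, not those of the real InceptionV3): parallel branches with different kernel sizes, each preceded by a cheap 1x1 convolution, with the branch outputs concatenated along the channel dimension.

import torch
import torch.nn as nn

class MiniInception(nn.Module):
    """Toy Inception-style module: parallel multi-scale branches;
    1x1 convs reduce the channels before the expensive 3x3/5x5 convs."""
    def __init__(self, in_channels):
        super().__init__()
        self.branch1 = nn.Conv2d(in_channels, 16, kernel_size=1)
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=1),  # NIN: reduce channels first
            nn.Conv2d(16, 24, kernel_size=3, padding=1))
        self.branch5 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=1),
            nn.Conv2d(16, 24, kernel_size=5, padding=2))
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_channels, 16, kernel_size=1))

    def forward(self, x):
        # All branches preserve height/width; concatenate on channels
        return torch.cat([self.branch1(x), self.branch3(x),
                          self.branch5(x), self.branch_pool(x)], dim=1)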

Factorized convolutions#

  • A 3x3 conv. can be replaced by combining 1x3 and 3x1, and is 33% cheaper

ml
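
The saving is easy to verify by counting parameters (64 channels chosen arbitrarily for illustration):

import torch.nn as nn

full = nn.Conv2d(64, 64, kernel_size=3, padding=1)          # plain 3x3 conv
factorized = nn.Sequential(                                  # 1x3 followed by 3x1
    nn.Conv2d(64, 64, kernel_size=(1, 3), padding=(0, 1)),
    nn.Conv2d(64, 64, kernel_size=(3, 1), padding=(1, 0)))
count = lambda m: sum(p.numel() for p in m.parameters())
print(count(full), count(factorized))  # 36928 vs 24704: about a third fewer weights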

ResNet50#

  • Residual (skip) connections: add earlier feature map to a later one (dimensions must match)

    • Information can bypass layers, reduces vanishing gradients, allows much deeper nets

  • Residual blocks: skip a small number of layers and repeat many times

    • Match dimensions through padding and 1x1 convolutions

    • When resolution drops, add 1x1 convolutions with stride 2

  • Can be combined with Inception blocks

ml
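
A minimal sketch of a residual block (the real ResNet blocks add batch normalization and 1x1 bottleneck convolutions, omitted here for clarity):

import torch.nn as nn

class ResidualBlock(nn.Module):
    """Two 3x3 convs whose output is added back to the input (skip connection)."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        return self.relu(out + x)  # add the input: dimensions must match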

Interpreting the model#

  • Let’s see what the convnet is learning exactly by observing the intermediate feature maps

  • We can do this easily by attaching a ‘hook’ to a layer so we can read its output (activation)

# The hook must live inside a function for `nonlocal` to work; minimal version:
def get_activation(model, layer_id, image_tensor):
    activation = None
    # Create a hook to extract intermediate output (activation)
    def hook_fn(module, input, output):
        nonlocal activation
        activation = output.detach()

    # Add a hook to a specific layer
    hook = model.features[layer_id].register_forward_hook(hook_fn)
    # Do a forward pass without gradient computation
    with torch.no_grad():
        model(image_tensor)
    hook.remove()  # Clean up the hook
    return activation

Result for a specific filter (Layer 0, Filter 0)

from PIL import Image
import os
import random

def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

def load_image(img_path, img_size=(150, 150)):
    """Load and preprocess image as a PyTorch tensor."""
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),  # Converts image to tensor with values in [0,1]
    ])
    
    img = Image.open(img_path).convert("RGB")  # Ensure RGB format
    img_tensor = transform(img).unsqueeze(0)  # Add batch dimension
    return img_tensor

def get_layer_activations(model, img_tensor, layer_idx=0, keep_gradients=False):
    """Extract activations from a specific layer."""
    activation = None
    
    def hook_fn(module, input, output):
        nonlocal activation
        if keep_gradients: # Only for gradient ascent (later)
            activation = output
        else:
            activation = output.detach()

    # Register hook to capture the activation
    # Handles our custom model and more general models like VGG
    layer = model.conv_layers[layer_idx] if hasattr(model, "conv_layers") else model[layer_idx]
    hook = layer.register_forward_hook(hook_fn)    
    
    if keep_gradients:
        model(img_tensor)  # Run the image through the model
    else:
        with torch.no_grad():
            model(img_tensor)  # Idem but no grad
    
    hook.remove()  # Remove the hook after getting activations
    return activation

def visualize_activations(model, img_tensor, layer_idx=0, filter_idx=0):
    """Visualize input image and activations of a selected filter."""

    # Get activations from the specified layer
    activations = get_layer_activations(model, img_tensor, layer_idx)

    # Convert activations to numpy for visualization
    activation_np = activations.squeeze(0).cpu().numpy()  # Remove batch dim
    
    # Show input image
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(4, 2))
    
    # Convert input tensor to NumPy
    img_np = img_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()  # (H, W, C)
    img_np = np.clip(img_np, 0, 1)  # Ensure values are in range [0,1]
    
    ax1.imshow(img_np)
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_xlabel("Input Image", fontsize=8)
    
    # Visualize a specific filter's activation
    ax2.imshow(activation_np[filter_idx], cmap="viridis")
    ax2.set_xticks([])
    ax2.set_yticks([])
    ax2.set_xlabel(f"Activation of Filter {filter_idx}", fontsize=8)
    
    plt.tight_layout()
    plt.show()
# Load an example image and visualize activations
img_path = os.path.join(data_dir, "train/cats/cat.1700.jpg")
img_tensor = load_image(img_path).to(accelerator)

visualize_activations(model, img_tensor, layer_idx=0, filter_idx=0)
../_images/f1f30915335c38447c62622d6a17e82c2177acbc762f137ac47f8758464bd756.png

The same filter responds quite differently for other inputs.

Hide code cell source
img_path_dog = os.path.join(data_dir, "train/dogs/dog.1528.jpg")
img_tensor_dog = load_image(img_path_dog).to(accelerator)

visualize_activations(model, img_tensor_dog, layer_idx=0, filter_idx=0)
../_images/9a654205f561f68a0f9ea39235fc502145bcf286aaefab24c33b2d782439e493.png
Hide code cell source
def visualize_all_filters(model, img_tensor, layer_idx=0, max_per_row=16):
    """Visualize all filters of a given layer as a grid of feature maps."""
    activations = get_layer_activations(model, img_tensor, layer_idx)
    activation_np = activations.squeeze(0).cpu().numpy()
    
    num_filters = activation_np.shape[0]
    num_cols = min(num_filters, max_per_row)
    num_rows = (num_filters + num_cols - 1) // num_cols  # Ceiling division
    
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols, num_rows))
    axes = np.array(axes).reshape(num_rows, num_cols)  # Ensure it's a 2D array
    
    for i in range(num_rows * num_cols):
        ax = axes[i // num_cols, i % num_cols]
        
        if i < num_filters:
            ax.imshow(activation_np[i], cmap="viridis")
        
        ax.set_xticks([])
        ax.set_yticks([])
    
    plt.suptitle(f"Activations of Layer {layer_idx}", fontsize=16, y=1.0)
    plt.tight_layout()
    plt.show()

Let’s plot all the activations to see what happens in the network

  • First 2 convolutional layers: various edge detectors

  • Empty filter activations occur:

    • Filter is not interested in that input image (maybe it’s dog-specific)

    • Incomplete training, Dying ReLU,…

Hide code cell source
visualize_all_filters(model, img_tensor, layer_idx=0)
visualize_all_filters(model, img_tensor, layer_idx=3)
../_images/8903c4a7bb5b139c6d132652dfd826bb62594cc97d92d14996e9fbf13c83eb2c.png ../_images/5e1edd38ed0261656e49dc47fe5e55db901e61c83755e66fc29d6db4dee6d9ac.png
  • 3rd convolutional layer: increasingly abstract: ears, nose, eyes

Hide code cell source
visualize_all_filters(model, img_tensor, layer_idx=6)
../_images/bdc9c87394b0a3c690c8f4741fdc2e0ba579bc5a58a1bd3cdcfa5999491a17fd.png
  • Last convolutional layer: more abstract patterns

  • Increasingly combine information from all previous filters

Hide code cell source
visualize_all_filters(model, img_tensor, layer_idx=9)
../_images/44c281af2e9f7311eb1d04259362fd1aaae8cd7c1bd9591ea54387d7c425d379.png
  • Same layer, with dog image input

    • Some filters react only to dogs (or cats)

    • Deeper layers learn representations that separate the classes

Hide code cell source
visualize_all_filters(model, img_tensor_dog, layer_idx=9)
../_images/3ed4b85a9ac333cb5f6236ddf031ab321e15c5c3119b67c9546872e3ef1e13e3.png

Spatial hierarchies#

  • Deep convnets can learn spatial hierarchies of patterns

    • First layer can learn very local patterns (e.g. edges)

    • Second layer can learn specific combinations of patterns

    • Every layer can learn increasingly complex abstractions

ml

Visualizing the learned filters#

  • Visualize filters by finding the input image that they are maximally responsive to

  • gradient ascent in input space : start from a random image \(X_{(0)}\) and repeatedly update the pixel values in the direction that increases the filter's activation \(L\) (keeping the weights fixed)

    • \(X_{(i+1)} = X_{(i)} + \eta \cdot \frac{\partial L(X_{(i)})}{\partial X}\)

    # Create a random input tensor and tell Adam to optimize the pixels
    input_img = torch.randn((1, 3, size, size), requires_grad=True)
    optimizer = torch.optim.Adam([input_img], lr=step_size)

    # Get activations (same as before) of this input for a specific filter
    # Compute gradient to update (maximize) the activations for this filter 
    for _ in range(steps):
        optimizer.zero_grad() 
        activations = get_layer_activations(model, input_img, layer_idx)
        loss = -activations[0, filter_idx].mean()  # Negated: Adam minimizes, so this maximizes the activation
        loss.backward()
        optimizer.step()
Hide code cell source
import torch.nn.functional as F
from torchvision import transforms

def deprocess_image(x):
    """Normalize and convert tensor to a displayable image."""
    x -= x.mean()
    x /= (x.std() + 1e-5)
    x *= 0.1
    x += 0.5
    x = np.clip(x, 0, 1)
    x *= 255
    return np.clip(x, 0, 255).astype('uint8')

def generate_pattern(model, layer_idx, filter_idx, size=56, steps=40, step_size=1.0):
    """Perform gradient ascent to generate a visualization of the filter."""
    model.eval()
    #device = next(model.parameters()).device
    
    # Start with a random image
    input_img = torch.randn((1, 3, size, size), requires_grad=True, device=accelerator)
    optimizer = torch.optim.Adam([input_img], lr=step_size)
    
    for _ in range(steps):
        optimizer.zero_grad()
        activations = get_layer_activations(model, input_img, layer_idx, keep_gradients=True)
        loss = -activations[0, filter_idx].mean()  # Negated: Adam minimizes, so this maximizes the activation
        loss.backward()
        optimizer.step()
    
    img = input_img.detach().squeeze().permute(1, 2, 0).cpu().numpy()
    return deprocess_image(img)

def visualize_filters(model, layer_idx, grid_size=(8, 8), size=64, margin=5):
    """Visualize all filters of a layer in a grid using gradient ascent."""
    num_filters = grid_size[0] * grid_size[1]
    results = np.zeros((grid_size[0] * size + (grid_size[0] - 1) * margin, 
                        grid_size[1] * size + (grid_size[1] - 1) * margin, 3), dtype=np.uint8)
    
    for i in range(grid_size[0]):
        for j in range(grid_size[1]):
            filter_idx = i * grid_size[1] + j
            filter_img = generate_pattern(model, layer_idx, filter_idx, size=size)
            
            h_start = i * size + i * margin
            h_end = h_start + size
            v_start = j * size + j * margin
            v_end = v_start + size
            
            results[h_start:h_end, v_start:v_end, :] = filter_img
    
    plt.figure(figsize=(4, 4))
    plt.imshow(results)
    plt.axis('off')
    plt.show()
  • Learned filters of second convolutional layer

  • Mostly general, some respond to specific shapes/colors

Hide code cell source
visualize_filters(model, layer_idx=3)

The next layer starts showing some structures / textures

visualize_filters(model, layer_idx=6)

Last convolutional layer

  • More focused on center, some vague cat/dog head shapes

ml

We need to go deeper. Bigger networks, more data!#

Let’s do this again for the VGG16 network pretrained on ImageNet

from torchvision.models import vgg16
vgg16_model = vgg16(pretrained=True) # Load with pretrained weights
vgg16_model = vgg16_model.features   # Keep only the features (not the head)
Hide code cell source
from torchvision.models import vgg16

# Load VGG16 pretrained on ImageNet
vgg16_model = vgg16(pretrained=True)

# Remove the fully connected layers (equivalent to include_top=False in Keras)
vgg16_model_feat = vgg16_model.features

# Set model to evaluation mode
vgg16_model_feat.eval();
# Print model summary
summary(vgg16_model_feat, input_size=(1, 3, 224, 224))  # Pretraining images were 224x224
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Sequential                               [1, 512, 7, 7]            --
├─Conv2d: 1-1                            [1, 64, 224, 224]         1,792
├─ReLU: 1-2                              [1, 64, 224, 224]         --
├─Conv2d: 1-3                            [1, 64, 224, 224]         36,928
├─ReLU: 1-4                              [1, 64, 224, 224]         --
├─MaxPool2d: 1-5                         [1, 64, 112, 112]         --
├─Conv2d: 1-6                            [1, 128, 112, 112]        73,856
├─ReLU: 1-7                              [1, 128, 112, 112]        --
├─Conv2d: 1-8                            [1, 128, 112, 112]        147,584
├─ReLU: 1-9                              [1, 128, 112, 112]        --
├─MaxPool2d: 1-10                        [1, 128, 56, 56]          --
├─Conv2d: 1-11                           [1, 256, 56, 56]          295,168
├─ReLU: 1-12                             [1, 256, 56, 56]          --
├─Conv2d: 1-13                           [1, 256, 56, 56]          590,080
├─ReLU: 1-14                             [1, 256, 56, 56]          --
├─Conv2d: 1-15                           [1, 256, 56, 56]          590,080
├─ReLU: 1-16                             [1, 256, 56, 56]          --
├─MaxPool2d: 1-17                        [1, 256, 28, 28]          --
├─Conv2d: 1-18                           [1, 512, 28, 28]          1,180,160
├─ReLU: 1-19                             [1, 512, 28, 28]          --
├─Conv2d: 1-20                           [1, 512, 28, 28]          2,359,808
├─ReLU: 1-21                             [1, 512, 28, 28]          --
├─Conv2d: 1-22                           [1, 512, 28, 28]          2,359,808
├─ReLU: 1-23                             [1, 512, 28, 28]          --
├─MaxPool2d: 1-24                        [1, 512, 14, 14]          --
├─Conv2d: 1-25                           [1, 512, 14, 14]          2,359,808
├─ReLU: 1-26                             [1, 512, 14, 14]          --
├─Conv2d: 1-27                           [1, 512, 14, 14]          2,359,808
├─ReLU: 1-28                             [1, 512, 14, 14]          --
├─Conv2d: 1-29                           [1, 512, 14, 14]          2,359,808
├─ReLU: 1-30                             [1, 512, 14, 14]          --
├─MaxPool2d: 1-31                        [1, 512, 7, 7]            --
==========================================================================================
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0
Total mult-adds (G): 15.36
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 108.38
Params size (MB): 58.86
Estimated Total Size (MB): 167.84
==========================================================================================
# Retry, with a custom class this time
class SaveFeatures():
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
    
    def hook_fn(self, module, input, output):
        self.features = output
    
    def close(self):
        self.hook.remove()

class FilterVisualizer():
    def __init__(self, size=56, upscaling_steps=12, upscaling_factor=1.2, device=None):
        self.size = size
        self.upscaling_steps = upscaling_steps
        self.upscaling_factor = upscaling_factor
        self.device = device if device is not None else accelerator  # Fall back to the global accelerator
        self.model = models.vgg16(pretrained=True).features.to(self.device).eval()

    def visualize(self, layer, filter, lr=0.1, opt_steps=20, blur=None):
        sz = self.size
        img = np.uint8(np.random.uniform(150, 180, (sz, sz, 3))) / 255  # generate random image
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        
        activations = SaveFeatures(self.model[layer])  # register hook

        for _ in range(self.upscaling_steps):  # scale the image up
            img_var = torch.tensor(transform(Image.fromarray((img * 255).astype(np.uint8)))[None],
                                   dtype=torch.float32, requires_grad=True, device=self.device)
            optimizer = optim.Adam([img_var], lr=lr, weight_decay=1e-6)
            
            for _ in range(opt_steps):  # optimize pixel values
                optimizer.zero_grad()
                self.model(img_var)
                loss = -activations.features[0, filter].mean()
                loss.backward()
                optimizer.step()
            
            img = img_var.detach().cpu().numpy()[0].transpose(1, 2, 0)
            self.output = img
            sz = int(self.upscaling_factor * sz)  # calculate new image size
            img = cv2.resize(img, (sz, sz), interpolation=cv2.INTER_CUBIC)  # scale image up
            if blur is not None:
                img = cv2.blur(img, (blur, blur))  # blur image
        
        activations.close()
        return self.output

    def visualize_filters(self, layer, num_filters=None, blur=None, filters=None):
        # Default to all filters: read out_channels from the (conv) layer itself,
        # since the hook's .features attribute only exists after a forward pass
        num_filters = num_filters or self.model[layer].out_channels
        filter_images = []

        if filters:
            num_filters = len(filters)
            for filter_idx in filters:
                img = self.visualize(layer, filter_idx, blur=blur)
                filter_images.append(img)
        else:            
            for filter_idx in range(num_filters):
                img = self.visualize(layer, filter_idx, blur=blur)
                filter_images.append(img)

        self.show_filters(filter_images, num_filters)

    def show_filters(self, filter_images, num_filters):
        cols = min(8, num_filters)  # Limit to max 8 columns
        rows = (num_filters // cols) + int(num_filters % cols > 0)

        # Dynamically adjust figure size
        fig, axes = plt.subplots(rows, cols, figsize=(cols, rows))
        axes = np.array(axes).flatten()  # Flatten in case of single row/col

        for i, img in enumerate(filter_images):
            axes[i].imshow(np.clip(img, 0, 1))
            axes[i].axis('off')

        # Remove empty subplots
        for i in range(len(filter_images), len(axes)):
            fig.delaxes(axes[i])  # Delete extra empty axes

        fig.subplots_adjust(wspace=0, hspace=0)  # Remove spacing between subplots
        plt.tight_layout(pad=0)  # Remove extra padding around the figure
        plt.show()

VGG Layer 5

filters=[12, 16, 86, 110]
FV = FilterVisualizer(size=56, upscaling_steps=12, upscaling_factor=1.2)
FV.visualize_filters(5, filters=filters, blur=5)
FV = FilterVisualizer(size=56, upscaling_steps=12, upscaling_factor=1.2)
FV.visualize_filters(5, num_filters=32, blur=5)

VGG Layer 10

FV = FilterVisualizer(size=56, upscaling_steps=12, upscaling_factor=1.2)
FV.visualize_filters(10, num_filters=32, blur=5)

VGG Layer 17

FV = FilterVisualizer(size=56, upscaling_steps=12, upscaling_factor=1.2)
FV.visualize_filters(17, num_filters=32, blur=5)

VGG Layer 24

FV = FilterVisualizer(size=56, upscaling_steps=12, upscaling_factor=1.2)
FV.visualize_filters(24, num_filters=32, blur=5)

VGG Layer 28

FV = FilterVisualizer(size=56, upscaling_steps=12, upscaling_factor=1.2)
FV.visualize_filters(28, num_filters=32, blur=5)

Visualizing class activation#

  • We can also visualize which part of the input image had the greatest influence on the final classification. Helps to interpret what the model is paying attention to.

  • Class activation maps : produces a heatmap over the input image

    • Choose a convolution layer, do Global Average Pooling (GAP) to get one output per filter

    • Get the weights between those outputs and the class of interest

    • Compute the weighted sum of all filter activations: combines what each filter is responding to and how much this affects the class prediction

ml
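
In the Grad-CAM formulation implemented below, with feature maps \(A^k\) of the chosen layer and class score \(y^c\), the per-filter weights are the spatially averaged gradients, and the heatmap is the ReLU of the weighted sum of feature maps:

  • \(\alpha_k^c = \frac{1}{Z}\sum_{i}\sum_{j} \frac{\partial y^c}{\partial A_{ij}^k}\)

  • \(L_{Grad\text{-}CAM}^c = \text{ReLU}\left(\sum_k \alpha_k^c A^k\right)\)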

Implementing gradCAM#

    target_layer = model.layer4[-1] # Last conv layer
    
    # Hooks to capture activations and gradients
    def forward_hook(module, input, output):
        activations = output

    def backward_hook(module, grad_input, grad_output):
        gradients = grad_output[0]
    # Register hooks
    target_layer.register_forward_hook(forward_hook)
    target_layer.register_full_backward_hook(backward_hook)

    # Forward pass
    output = model(img_tensor)
    pred_class = output.argmax(dim=1).item()
    # Backward pass to compute gradients
    model.zero_grad()
    output[:, pred_class].backward()

    # Compute Grad-CAM heatmap
    weights = torch.mean(gradients, dim=[2, 3], keepdim=True)  # GAP layer
    heatmap = torch.sum(weights * activations, dim=1).squeeze()

Example on ResNet50 with a specific input image (class Elephant)

import torch
import torch.nn.functional as F
import numpy as np
import cv2
import matplotlib.pyplot as plt
from torchvision import models, transforms


def gradCAM(img_path):
    # Load a pre-trained model (e.g., ResNet50)
    model = models.resnet50(pretrained=True)
    model.eval()
    
    # Define the target convolutional layer for Grad-CAM (last conv layer of ResNet)
    target_layer = model.layer4[-1]

    # Define the preprocessing transformation
    preprocess = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Load and preprocess the image
    original_img = cv2.imread(img_path)
    original_img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
    img_tensor = preprocess(original_img).unsqueeze(0)  # Add batch dimension

    # Store activations and gradients inside function scope
    activations = None
    gradients = None

    # Hooks to capture activations and gradients
    def forward_hook(module, input, output):
        nonlocal activations  # Ensure we modify the function-scoped variable
        activations = output

    def backward_hook(module, grad_input, grad_output):
        nonlocal gradients  # Ensure we modify the function-scoped variable
        gradients = grad_output[0]

    # Register hooks
    target_layer.register_forward_hook(forward_hook)
    target_layer.register_full_backward_hook(backward_hook)

    # Forward pass
    output = model(img_tensor)
    pred_class = output.argmax(dim=1).item()

    # Backward pass to compute gradients
    model.zero_grad()
    output[:, pred_class].backward()

    # Ensure hooks stored values
    if activations is None or gradients is None:
        raise RuntimeError("Hooks did not capture activations or gradients.")

    # Compute Grad-CAM heatmap
    weights = torch.mean(gradients, dim=[2, 3], keepdim=True)  # Global Average Pooling
    heatmap = torch.sum(weights * activations, dim=1).squeeze()
    heatmap = heatmap.cpu().detach().numpy()

    # Normalize heatmap
    heatmap = np.maximum(heatmap, 0)  # ReLU (only positive values)
    heatmap /= (np.max(heatmap) + 1e-10)  # Normalize to [0,1] to avoid division by zero

    # Resize heatmap to match original image size
    heatmap = cv2.resize(heatmap, (original_img.shape[1], original_img.shape[0]))

    # Convert heatmap to colormap
    heatmap = np.uint8(255 * heatmap)  # Scale to 0-255
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # Apply colormap (returns BGR)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # Convert to RGB to match the input image

    # Overlay heatmap on original image
    superimposed_img = cv2.addWeighted(original_img, 0.6, heatmap, 0.4, 0)

    # Display result
    plt.figure(figsize=(8, 8))
    plt.imshow(superimposed_img)
    plt.axis("off")
    plt.title("Grad-CAM Heatmap Overlay", fontsize=12)
    plt.show()
    
img_path = '../notebooks/images/10_elephants.jpg'
gradCAM(img_path)
../_images/d6d294ac9dff11d666339feec06a14c0304e6dfc3248368e16794b263cee1153.png

Transfer learning#

  • We can re-use pretrained networks instead of training from scratch

  • Learned features can be a generic model of the visual world

  • Use the convolutional base to construct features, then train any classifier on the new data

  • This re-use is called transfer learning, which is a kind of meta-learning

ml
  • Let’s instantiate the VGG16 model (without the dense layers)

  • The final feature map has shape (512, 7, 7) (channels first in PyTorch)

vgg16_model = vgg16(pretrained=True)
conv_base = vgg16_model.features
Hide code cell source
from torchvision.models import vgg16

# Load VGG16 pretrained on ImageNet
vgg16_model = vgg16(pretrained=True)

# Remove the fully connected layers
conv_base = vgg16_model.features

# Set model to evaluation mode
conv_base.eval();

Using pre-trained networks: 3 ways#

  • Fast feature extraction (for similar task, little data)

    • Run the data through the convolutional base once to build new features (sketched below)

    • Use outputs as input to a new neural net (or other algorithm)

  • End-to-end tuning (for similar task, lots of data + data augmentation)

    • Extend the convolutional base model with a new dense layer

    • Train it end to end on the new data (expensive!)

  • Fine-tuning (for somewhat different task)

    • Unfreeze a few of the top convolutional layers, and retrain

    • Update only the more abstract representations

ml
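
A minimal sketch of the first option (our illustration; conv_base and the dataloaders are as defined elsewhere in this notebook): precompute the features once, then fit any classifier on them.

import torch

def extract_features(conv_base, dataloader, device="cpu"):
    """Run each batch through the frozen conv base once and collect
    the flattened feature maps together with the labels."""
    conv_base.eval().to(device)
    feats, labels = [], []
    with torch.no_grad():
        for x, y in dataloader:
            feats.append(conv_base(x.to(device)).flatten(1).cpu())
            labels.append(y)
    return torch.cat(feats), torch.cat(labels)

# X, y = extract_features(conv_base, data_module.train_dataloader())
# Any classifier (a small dense net, logistic regression, ...) can now be fit on X, y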

Fast feature extraction (without data augmentation)#

  • Build Dense neural net (with Dropout)

  • Run every batch through the pre-trained (frozen) convolutional base

class ImageClassifier(pl.LightningModule):
    def __init__(self, conv_base, input_dim):
        super().__init__()
        self.conv_base = conv_base  # Keep conv layers
        self.conv_base.eval()  # Set to eval mode
        for param in self.conv_base.parameters():
            param.requires_grad = False  # Freeze convolutional layers
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
            nn.Sigmoid()  # BCELoss expects probabilities
        )
        self.criterion = nn.BCELoss()
    def forward(self, x):
        with torch.no_grad():  # Prevent gradients in conv layers
            features = self.conv_base(x)
            features = features.view(features.size(0), -1)  # Flatten
        return self.classifier(features)

Create and train the model

# Get pretrained network
conv_base = vgg16(pretrained=True).features.to(accelerator)
input_dim = 512 * 7 * 7  # Feature map size from VGG16

# Initialize DataModule and Model
model = ImageClassifier(conv_base, input_dim)
data_module = ImageDataModule(train_dir, val_dir)
# Train
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    accelerator=accelerator,
)
trainer.fit(model, data_module)
Hide code cell source
# Hyperparameters
BATCH_SIZE = 32
EPOCHS = 30
LEARNING_RATE = 1e-4
IMAGE_SIZE = (150, 150)

# Define DataModule
class ImageDataModule(pl.LightningDataModule):
    def __init__(self, train_dir, val_dir, batch_size=BATCH_SIZE):
        super().__init__()
        self.train_dir = train_dir
        self.val_dir = val_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def setup(self, stage=None):
        self.train_dataset = datasets.ImageFolder(self.train_dir, transform=self.transform)
        self.val_dataset = datasets.ImageFolder(self.val_dir, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)


# Define Model
class ImageClassifier(pl.LightningModule):
    def __init__(self, conv_base, input_dim):
        super().__init__()
        self.conv_base = conv_base  # Keep conv layers
        self.conv_base.eval()  # Set to eval mode (but still in model)
        for param in self.conv_base.parameters():
            param.requires_grad = False  # Freeze convolutional layers

        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        self.criterion = nn.BCELoss()

    def forward(self, x):
        with torch.no_grad():  # Prevent gradients in conv layers
            features = self.conv_base(x)
            features = features.view(features.size(0), -1)  # Flatten
        return self.classifier(features)

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self(inputs)  # Forward pass includes feature extraction
        # BCELoss needs float targets with the same (B, 1) shape as the outputs
        loss = self.criterion(outputs, targets.float().unsqueeze(1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        targets = targets.float().unsqueeze(1)  # Match output shape and dtype
        outputs = self(inputs)
        loss = self.criterion(outputs, targets)
        predicted = (outputs > 0.5).float()
        accuracy = (predicted == targets).float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", accuracy, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.classifier.parameters(), lr=LEARNING_RATE)

train_dir = "/path/to/train"
val_dir = "/path/to/validation"

# Initialize DataModule and Model
conv_base = vgg16(pretrained=True).features.to(accelerator)
input_dim = 512 * 7 * 7  # Feature map size from VGG16

model = ImageClassifier(conv_base, input_dim)
data_module = ImageDataModule(train_dir, val_dir)

trainer = pl.Trainer(
    max_epochs=EPOCHS,
    accelerator=accelerator,
)

trainer.fit(model, data_module)

Fast feature extraction (with data augmentation)#

  • Simply add the Dense layers to the convolutional base

  • Freeze the convolutional base

  • Add data augmentation to the data module

class FFEClassifier(nn.Module):
    def __init__(self, conv_base):
        super(FFEClassifier, self).__init__()
        self.conv_base = conv_base  # Pretrained feature extractor
        self.flatten = nn.Flatten()  # Flatten feature maps
        self.fc = nn.Sequential(
            nn.Linear(512 * 4 * 4, 256),  # Adjust input size based on feature maps
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()  # Outputs probability (0 to 1)
        )
    def forward(self, x):
        with torch.no_grad():  # Ensure no gradients for the conv_base
            x = self.conv_base(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
    
for param in conv_base.parameters(): # Freeze
    param.requires_grad = False
model = FFEClassifier(conv_base)
trainer.fit(model, data_module)
Hide code cell source
# Load pre-trained VGG16 model
conv_base = models.vgg16(pretrained=True).features  # Only convolutional layers

# Freeze convolutional base
for param in conv_base.parameters():
    param.requires_grad = False

# Define PyTorch model
class FFEClassifier(nn.Module):
    def __init__(self, conv_base):
        super(FFEClassifier, self).__init__()
        self.conv_base = conv_base  # Pretrained feature extractor
        self.flatten = nn.Flatten()  # Flatten feature maps
        self.fc = nn.Sequential(
            nn.Linear(512 * 4 * 4, 256),  # Adjust input size based on feature maps
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()  # Outputs probability (0 to 1)
        )

    def forward(self, x):
        with torch.no_grad():  # Ensure no gradients for the conv_base
            x = self.conv_base(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

# Initialize model
model = FFEClassifier(conv_base)

Architecture

summary(model, input_size=(1, 3, 150, 150))  # 150x150 inputs yield 512x4x4 feature maps
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
FFEClassifier                            [1, 1]                    --
├─Sequential: 1-1                        [1, 512, 4, 4]            --
│    └─Conv2d: 2-1                       [1, 64, 150, 150]         (1,792)
│    └─ReLU: 2-2                         [1, 64, 150, 150]         --
│    └─Conv2d: 2-3                       [1, 64, 150, 150]         (36,928)
│    └─ReLU: 2-4                         [1, 64, 150, 150]         --
│    └─MaxPool2d: 2-5                    [1, 64, 75, 75]           --
│    └─Conv2d: 2-6                       [1, 128, 75, 75]          (73,856)
│    └─ReLU: 2-7                         [1, 128, 75, 75]          --
│    └─Conv2d: 2-8                       [1, 128, 75, 75]          (147,584)
│    └─ReLU: 2-9                         [1, 128, 75, 75]          --
│    └─MaxPool2d: 2-10                   [1, 128, 37, 37]          --
│    └─Conv2d: 2-11                      [1, 256, 37, 37]          (295,168)
│    └─ReLU: 2-12                        [1, 256, 37, 37]          --
│    └─Conv2d: 2-13                      [1, 256, 37, 37]          (590,080)
│    └─ReLU: 2-14                        [1, 256, 37, 37]          --
│    └─Conv2d: 2-15                      [1, 256, 37, 37]          (590,080)
│    └─ReLU: 2-16                        [1, 256, 37, 37]          --
│    └─MaxPool2d: 2-17                   [1, 256, 18, 18]          --
│    └─Conv2d: 2-18                      [1, 512, 18, 18]          (1,180,160)
│    └─ReLU: 2-19                        [1, 512, 18, 18]          --
│    └─Conv2d: 2-20                      [1, 512, 18, 18]          (2,359,808)
│    └─ReLU: 2-21                        [1, 512, 18, 18]          --
│    └─Conv2d: 2-22                      [1, 512, 18, 18]          (2,359,808)
│    └─ReLU: 2-23                        [1, 512, 18, 18]          --
│    └─MaxPool2d: 2-24                   [1, 512, 9, 9]            --
│    └─Conv2d: 2-25                      [1, 512, 9, 9]            (2,359,808)
│    └─ReLU: 2-26                        [1, 512, 9, 9]            --
│    └─Conv2d: 2-27                      [1, 512, 9, 9]            (2,359,808)
│    └─ReLU: 2-28                        [1, 512, 9, 9]            --
│    └─Conv2d: 2-29                      [1, 512, 9, 9]            (2,359,808)
│    └─ReLU: 2-30                        [1, 512, 9, 9]            --
│    └─MaxPool2d: 2-31                   [1, 512, 4, 4]            --
├─Flatten: 1-2                           [1, 8192]                 --
├─Sequential: 1-3                        [1, 1]                    --
│    └─Linear: 2-32                      [1, 256]                  2,097,408
│    └─ReLU: 2-33                        [1, 256]                  --
│    └─Linear: 2-34                      [1, 1]                    257
│    └─Sigmoid: 2-35                     [1, 1]                    --
==========================================================================================
Total params: 16,812,353
Trainable params: 2,097,665
Non-trainable params: 14,714,688
Total mult-adds (G): 6.62
==========================================================================================
Input size (MB): 0.27
Forward/backward pass size (MB): 47.95
Params size (MB): 67.25
Estimated Total Size (MB): 115.47
==========================================================================================

Fine-tuning#

  • Add your custom network on top of an already trained base network.

  • Freeze the base network, but unfreeze some of the top convolutional layers.

class FineTunedClassifier(nn.Module):
    def __init__(self, conv_base, unfreeze_from=10): 
        super(FineTunedClassifier, self).__init__()
        self.conv_base = conv_base  # Pretrained feature extractor

        # Freeze all layers first
        for param in self.conv_base.parameters():
            param.requires_grad = False  
        # Unfreeze deeper layers
        for layer in list(self.conv_base.children())[unfreeze_from:]:
            for param in layer.parameters():
                param.requires_grad = True

        self.flatten = nn.Flatten()  # Flatten feature maps
        self.fc = nn.Sequential(
            nn.Linear(512 * 4 * 4, 256),  # Adjust input size based on feature maps
            nn.ReLU(),
            nn.Linear(256, 1))

Visualized

ml ml
Hide code cell source
class FFEClassifier(nn.Module):
    def __init__(self, conv_base, unfreeze_from=10):  # Unfreeze from a specific layer index
        super(FFEClassifier, self).__init__()
        self.conv_base = conv_base  # Pretrained feature extractor

        # Unfreeze only some layers (Fine-tuning)
        for param in self.conv_base.parameters():
            param.requires_grad = False  # Freeze all layers first

        # Unfreeze deeper layers (from `unfreeze_from` index onward)
        for layer in list(self.conv_base.children())[unfreeze_from:]:
            for param in layer.parameters():
                param.requires_grad = True

        self.flatten = nn.Flatten()  # Flatten feature maps
        self.fc = nn.Sequential(
            nn.Linear(512 * 4 * 4, 256),  # Adjust input size based on feature maps
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()  # Outputs probability (0 to 1)
        )

    def forward(self, x):
        x = self.conv_base(x)  # Now includes gradients for unfrozen layers
        x = self.flatten(x)
        x = self.fc(x)
        return x

# Example usage:
conv_base = vgg16(pretrained=True).features  # Load VGG16 feature extractor
ft_model = FFEClassifier(conv_base, unfreeze_from=10)  # Unfreeze layers from index 10 onward
Hide code cell source
summary(ft_model, input_size=(1, 3, 150, 150))  # 150x150 inputs yield 512x4x4 feature maps
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
FFEClassifier                            [1, 1]                    --
├─Sequential: 1-1                        [1, 512, 4, 4]            --
│    └─Conv2d: 2-1                       [1, 64, 150, 150]         (1,792)
│    └─ReLU: 2-2                         [1, 64, 150, 150]         --
│    └─Conv2d: 2-3                       [1, 64, 150, 150]         (36,928)
│    └─ReLU: 2-4                         [1, 64, 150, 150]         --
│    └─MaxPool2d: 2-5                    [1, 64, 75, 75]           --
│    └─Conv2d: 2-6                       [1, 128, 75, 75]          (73,856)
│    └─ReLU: 2-7                         [1, 128, 75, 75]          --
│    └─Conv2d: 2-8                       [1, 128, 75, 75]          (147,584)
│    └─ReLU: 2-9                         [1, 128, 75, 75]          --
│    └─MaxPool2d: 2-10                   [1, 128, 37, 37]          --
│    └─Conv2d: 2-11                      [1, 256, 37, 37]          295,168
│    └─ReLU: 2-12                        [1, 256, 37, 37]          --
│    └─Conv2d: 2-13                      [1, 256, 37, 37]          590,080
│    └─ReLU: 2-14                        [1, 256, 37, 37]          --
│    └─Conv2d: 2-15                      [1, 256, 37, 37]          590,080
│    └─ReLU: 2-16                        [1, 256, 37, 37]          --
│    └─MaxPool2d: 2-17                   [1, 256, 18, 18]          --
│    └─Conv2d: 2-18                      [1, 512, 18, 18]          1,180,160
│    └─ReLU: 2-19                        [1, 512, 18, 18]          --
│    └─Conv2d: 2-20                      [1, 512, 18, 18]          2,359,808
│    └─ReLU: 2-21                        [1, 512, 18, 18]          --
│    └─Conv2d: 2-22                      [1, 512, 18, 18]          2,359,808
│    └─ReLU: 2-23                        [1, 512, 18, 18]          --
│    └─MaxPool2d: 2-24                   [1, 512, 9, 9]            --
│    └─Conv2d: 2-25                      [1, 512, 9, 9]            2,359,808
│    └─ReLU: 2-26                        [1, 512, 9, 9]            --
│    └─Conv2d: 2-27                      [1, 512, 9, 9]            2,359,808
│    └─ReLU: 2-28                        [1, 512, 9, 9]            --
│    └─Conv2d: 2-29                      [1, 512, 9, 9]            2,359,808
│    └─ReLU: 2-30                        [1, 512, 9, 9]            --
│    └─MaxPool2d: 2-31                   [1, 512, 4, 4]            --
├─Flatten: 1-2                           [1, 8192]                 --
├─Sequential: 1-3                        [1, 1]                    --
│    └─Linear: 2-32                      [1, 256]                  2,097,408
│    └─ReLU: 2-33                        [1, 256]                  --
│    └─Linear: 2-34                      [1, 1]                    257
│    └─Sigmoid: 2-35                     [1, 1]                    --
==========================================================================================
Total params: 16,812,353
Trainable params: 16,552,193
Non-trainable params: 260,160
Total mult-adds (G): 6.62
==========================================================================================
Input size (MB): 0.27
Forward/backward pass size (MB): 47.95
Params size (MB): 67.25
Estimated Total Size (MB): 115.47
==========================================================================================
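
When training such a partially unfrozen model, it is common practice (not shown in the lecture code) to use a much lower learning rate, so that the pretrained weights are only gently adjusted; a sketch:

import torch.optim as optim

# Pass only the unfrozen parameters; a small learning rate avoids
# destroying the pretrained representations in the unfrozen layers
trainable_params = [p for p in ft_model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-5)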

Take-aways#

  • Convnets are ideal for addressing image-related problems.

  • They learn a hierarchy of modular patterns and concepts to represent the visual world.

  • Representations are easy to inspect

  • Data augmentation helps fight overfitting

  • You can use a pretrained convnet to build better models via transfer learning