8  Hands-On Lab: PyTorch

Overview

This hands-on lab session offers participants practical experience with PyTorch for building, training, and evaluating neural network models. Participants will work with a sample dataset, load a pre-trained model, and fine-tune it to enhance performance. The session will guide participants through the building blocks of a deep learning application, including data, models, loss functions, optimization algorithms, training and evaluation, inference and visualization.

Note

You can access the Jupyter notebook for this hands-on lab on Google Colab.

8.1 Data

8.1.1 Dataset overview

In this hands-on lab, we will use a sample RTS (retrogressive thaw slump) dataset from Dr. Yili Yang’s research. While the RTS dataset was originally used for semantic segmentation, we will repurpose it for a classification task. The goal is to predict the number of RTS present in each image. The counts range from 1 to 10 and serve as the ground-truth labels for our model.

The dataset structure and required files for this hands-on lab are as follows:

cyber2a
│--- rts
│    │--- images  # Folder containing RGB images
│    │    │--- train_nitze_000.jpg
│    │    │--- train_nitze_001.jpg
│    │    │--- ...
│--- data_split.json
│--- rts_cls.json
  • data_split.json: A dictionary with two keys, train and valtest:
    • train: A list of image filenames for training.
    • valtest: A list of image filenames for validation and testing.
  • rts_cls.json: A dictionary with image filenames as keys and the number of RTS in each image as values.
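
Before training, it can help to check how the labels are distributed across the two splits. The sketch below is one way to do that, not part of the original lab: summarize_labels is a hypothetical helper, and the file names and counts in the example dictionaries are made up to mirror the layout of data_split.json and rts_cls.json.

```python
from collections import Counter


def summarize_labels(data_split, img_labels):
    """Count how many images carry each RTS count, per split."""
    summary = {}
    for split_name, filenames in data_split.items():
        # Counter maps each RTS count to the number of images with that count
        summary[split_name] = Counter(int(img_labels[name]) for name in filenames)
    return summary


# Tiny hypothetical stand-ins with the same layout as the two JSON files
example_split = {
    "train": ["a.jpg", "b.jpg", "c.jpg"],
    "valtest": ["d.jpg"],
}
example_labels = {"a.jpg": 1, "b.jpg": 3, "c.jpg": 1, "d.jpg": 2}

example_summary = summarize_labels(example_split, example_labels)
for split_name, counts in example_summary.items():
    print(split_name, dict(sorted(counts.items())))
```

With the real files, the two dictionaries would come from json.load on cyber2a/data_split.json and cyber2a/rts_cls.json. A strongly skewed distribution would suggest reading the validation accuracy with some care.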

8.1.2 Download the dataset

To download the dataset, run the following commands:

!wget --content-disposition "https://www.dropbox.com/scl/fi/1pz52tq3puomi0185ccyq/cyber2a.zip?rlkey=3dgf4gfrj9yk1k4p2znn9grso&st=bapbt1bq&dl=0"

!unzip -o cyber2a.zip

8.1.3 Visualize the dataset

Let’s visualize the dataset by displaying one image and its corresponding label:

import os
import json 

from PIL import Image
import matplotlib.pyplot as plt

# Define the directory where images are stored
img_dir = "cyber2a/rts/images/"

# Load the data split file to get lists of training and validation/test images
with open("cyber2a/data_split.json", 'r') as f:
    data_split = json.load(f)

# Retrieve the list of training images
img_list = data_split["train"]

# Load the image labels, where each image name maps to the number of RTS in the image
with open("cyber2a/rts_cls.json", 'r') as f:
    img_labels = json.load(f)

# Select the first image file name from the training list and get its corresponding label
img_name = img_list[0]
img_label = img_labels[img_name]

# Print the image file name and its corresponding number of RTS
print(f"Image Name: {img_name}, Number of RTS: {img_label}")

# Construct the full path to the image file
img_path = os.path.join(img_dir, img_name)

# Open the image and convert it to RGB format
image = Image.open(img_path).convert("RGB")

# Convert the label to 0-indexed format for classification tasks
label = int(img_label) - 1

# Display the image using matplotlib
fig, ax = plt.subplots()
ax.imshow(image)
ax.axis('off')  # Hide the axis
ax.set_title(f"Label: {label}")  # Set the title to the 0-indexed label

# Show the plot
plt.show()

8.1.4 Build a custom dataset

To build a custom dataset, we will create a PyTorch dataset class that loads the images and their corresponding labels. The dataset class will implement the following methods:

  • __init__: Initialize the dataset by loading the image filenames and labels.
  • __len__: Return the total number of images in the dataset.
  • __getitem__: Load an image and its corresponding label based on the index.

from torch.utils.data import Dataset

class RTSDataset(Dataset):
    def __init__(self, split, transform=None):
        """
        Args:
            split (str): One of 'train' or 'valtest' to specify the dataset split.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        # Define the directory where images are stored
        self.img_dir = "cyber2a/rts/images/"
        
        # Load the list of images based on the split (train/valtest)
        with open("cyber2a/data_split.json") as f:
            data_split = json.load(f)
            
        if split == 'train':
            self.img_list = data_split['train']
        elif split == 'valtest':
            self.img_list = data_split['valtest']
        else:
            raise ValueError("Invalid split: choose either 'train' or 'valtest'")
    
        # Load the image labels
        with open("cyber2a/rts_cls.json") as f:
            self.img_labels = json.load(f)

        # Store the transform to be applied to images
        self.transform = transform

    def __len__(self):
        """Return the total number of images."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the image to retrieve.
        
        Returns:
            tuple: (image, label) where image is the image tensor and label is the corresponding label.
        """
        # Retrieve the image name using the index
        img_name = self.img_list[idx]
      
        # Construct the full path to the image file
        img_path = os.path.join(self.img_dir, img_name)
        
        # Open the image and convert it to RGB format
        image = Image.open(img_path).convert('RGB')
        
        # Get the corresponding label and adjust it to be 0-indexed
        label = int(self.img_labels[img_name]) - 1

        # Apply the transform if specified
        if self.transform:
            image = self.transform(image)

        return image, label

8.1.5 Test the custom dataset

To test the custom dataset, we will create an instance of the RTSDataset class for the training split and display the first three images with their corresponding labels:

def display_sample_images(dataset, num_images=3):
    """
    Display sample images from the dataset.

    Args:
        dataset (Dataset): The dataset to sample images from.
        num_images (int): Number of images to display.
    """
    data, label = dataset[0]
    if isinstance(data, dict):
        num_modalities = len(data)
        fig, axs = plt.subplots(num_modalities, num_images, figsize=(20, 5))
        for i in range(num_images):
            data, label = dataset[i]
            for j, modality in enumerate(data):
                axs[j, i].imshow(data[modality])
                if j == 0:
                    axs[j, i].set_title(f"label: {label}")
                else:
                    axs[j, i].set_title(f"modality: {modality}")
                axs[j, i].axis("off")

    else:
        fig, axs = plt.subplots(1, num_images, figsize=(20, 5))
        for i in range(num_images):
            data, label = dataset[i]
            axs[i].imshow(data)
            axs[i].set_title(f"Label: {label}")
            axs[i].axis("off")

    plt.show()

# Create the training dataset
train_dataset = RTSDataset("train")

# Display sample images from the training dataset
display_sample_images(train_dataset)

8.1.6 Define data transforms and data loaders

To prepare the data for training, we will define data transforms that resize the images, convert them to PyTorch tensors, and normalize them with the ImageNet channel means and standard deviations. We will also create data loaders to load the data in batches during training and validation.

import torch
import torchvision.transforms as T

# Define the transform for the dataset
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create the training and validation datasets with transforms
train_dataset = RTSDataset("train", transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)

val_dataset = RTSDataset("valtest", transform=transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False)
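
To make the normalization step concrete, here is a small plain-Python sketch of the arithmetic that T.ToTensor and T.Normalize apply to a single 8-bit RGB pixel. It is an illustration only: normalize_pixel is a hypothetical helper, and the real transforms operate on whole tensors rather than on single pixels.

```python
# ImageNet channel statistics, copied from the transform above
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    # T.ToTensor scales 8-bit values from [0, 255] to [0.0, 1.0]
    scaled = [value / 255.0 for value in rgb]
    # T.Normalize computes (value - mean) / std per channel
    return [(s - m) / d for s, m, d in zip(scaled, MEAN, STD)]


black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
print([round(v, 3) for v in black])
print([round(v, 3) for v in white])
```

After normalization the values are no longer confined to [0, 1]: dark pixels become negative and bright pixels positive. This is why a normalized tensor looks wrong if it is passed straight to imshow.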

8.2 Models

8.2.1 Load a pre-trained model

In this hands-on lab, we will use a pre-trained ResNet-18 model as the backbone for our classification task. We will load the pre-trained ResNet-18 model from the torchvision library and modify the final fully connected layer to output 10 classes corresponding to the number of RTS in the images.

from torchvision import models # https://pytorch.org/vision/stable/models.html
from torchvision.models.resnet import ResNet18_Weights

# Load the pretrained ResNet18 model
model = models.resnet18(weights=ResNet18_Weights.DEFAULT)

# Modify the final layer to match the number of classes
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 10)

# print the model to observe the new `fc` layer
print(model)
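
As a rough sense of scale, the arithmetic below compares the size of the original 1000-class ImageNet head with the new 10-class head. It is a sketch under one assumption: that model.fc.in_features is 512, which is the value for torchvision's ResNet-18. linear_param_count is a hypothetical helper, not a PyTorch function.

```python
def linear_param_count(in_features, out_features, bias=True):
    """Number of trainable parameters in a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


# Assumption: model.fc.in_features == 512 for torchvision's ResNet-18
IN_FEATURES = 512

original_head = linear_param_count(IN_FEATURES, 1000)  # ImageNet head, 1000 classes
new_head = linear_param_count(IN_FEATURES, 10)         # RTS head, 10 classes

print(f"Original head parameters: {original_head}")
print(f"New head parameters: {new_head}")
```

Only the new head starts from random weights. The rest of the network keeps its pre-trained weights, and all of it is fine-tuned because the optimizer receives model.parameters().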

8.3 Loss functions

For the classification task, we will use the cross-entropy loss function, which is commonly used for multi-class classification problems.

import torch

# Define the loss function
criterion = torch.nn.CrossEntropyLoss()
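
torch.nn.CrossEntropyLoss takes the raw model outputs (logits), so the model needs no softmax layer of its own. For a single sample the loss is the negative log of the softmax probability assigned to the true class. The plain-Python sketch below reproduces that arithmetic for illustration; the cross_entropy helper and the example logits are hypothetical, and by default the PyTorch loss also averages over the batch.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    # Subtract the max logit for numerical stability; the result is unchanged
    shift = max(logits)
    log_sum_exp = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return log_sum_exp - logits[target]


# Hypothetical logits for a 3-class problem (the lab itself uses 10 classes)
logits = [2.0, 0.5, -1.0]
loss_correct = cross_entropy(logits, target=0)  # true class has the largest logit
loss_wrong = cross_entropy(logits, target=2)    # true class has the smallest logit

# A model that scores all 10 classes equally has loss log(10)
uniform_loss = cross_entropy([0.0] * 10, target=4)

print(f"Loss when the true class scores highest: {loss_correct:.4f}")
print(f"Loss when the true class scores lowest:  {loss_wrong:.4f}")
print(f"Loss for a uniform 10-class guess:       {uniform_loss:.4f}")
```

The uniform case is a useful reference point: a training loss close to log(10), roughly 2.30, means the model is doing no better than guessing among the 10 classes.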

8.4 Optimization algorithms

We will use the SGD (Stochastic Gradient Descent) optimizer to train the model.

import torch.optim as optim

# Define the optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
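
The momentum term keeps a running "velocity" of past gradients, so consistent gradient directions build up speed. The sketch below shows the update on plain Python lists, following the rule documented for torch.optim.SGD with its default dampening of 0 and without weight decay or Nesterov momentum. sgd_momentum_step is a hypothetical helper, and the constant gradient is made up for illustration.

```python
def sgd_momentum_step(params, grads, velocity, lr=0.001, momentum=0.9):
    """One SGD-with-momentum update on plain Python lists.

    v <- momentum * v + grad
    p <- p - lr * v
    """
    new_velocity = [momentum * v + g for v, g in zip(velocity, grads)]
    new_params = [p - lr * v for p, v in zip(params, new_velocity)]
    return new_params, new_velocity


params = [1.0]
velocity = [0.0]  # starting from zero makes the first step a plain SGD step
for step in range(2):
    grads = [2.0]  # hypothetical constant gradient
    params, velocity = sgd_momentum_step(params, grads, velocity)
    print(f"step {step + 1}: param={params[0]:.4f}, velocity={velocity[0]:.1f}")
```

With the same gradient on both steps, the second update is larger than the first because the velocity has accumulated.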

8.5 Training and evaluation

8.5.1 Define the training and evaluation functions

The training function will iterate over the training dataset, compute the loss, backpropagate the gradients, and update the model parameters. The evaluation function will iterate over the validation dataset and compute the accuracy of the model.

import torch
from tqdm import tqdm

def train(model, criterion, optimizer, train_loader, val_loader, num_epochs=5):
    """
    Train the model.
    
    Args:
        model: The model to train.
        criterion: The loss function.
        optimizer: The optimizer.
        train_loader: DataLoader for the training data.
        val_loader: DataLoader for the validation data.
        num_epochs (int): Number of epochs to train.
    
    Returns:
        model: The trained model.
    """
    # Look up the model's device once; it is used for both training and validation
    device = next(model.parameters()).device

    for epoch in range(num_epochs):
        # Set model to training mode
        model.train()
        running_loss = 0.0
        for i, data in enumerate(tqdm(train_loader)):
            inputs, labels = data

            # Move data to the appropriate device
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass to get model outputs
            outputs = model(inputs)

            # Compute the loss
            loss = criterion(outputs, labels)
            # Backward pass to compute gradients
            loss.backward()
            # Update model parameters
            optimizer.step()

            # Accumulate the running loss
            running_loss += loss.item()
        
        epoch_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")
        
        # Validation phase
        # Set the model to evaluation mode
        model.eval()
        correct = 0
        total = 0
        # Disable gradient computation for validation
        with torch.no_grad():
            for data in val_loader:
                images, labels = data
                # Move validation data to the appropriate device
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                # Get the predicted class
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f"Validation accuracy: {100 * correct / total}%")

    return model
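
The accuracy bookkeeping in the validation loop comes down to taking the index of the largest score in each output row and comparing it with the label. The plain-Python sketch below mirrors that logic on a made-up batch. argmax and batch_accuracy are hypothetical helpers that stand in for torch.max(outputs, 1) and the tensor comparison.

```python
def argmax(scores):
    """Index of the largest score in a list."""
    return max(range(len(scores)), key=scores.__getitem__)


def batch_accuracy(outputs, labels):
    """Return (number of correct predictions, batch size)."""
    predicted = [argmax(row) for row in outputs]
    correct = sum(int(p == y) for p, y in zip(predicted, labels))
    return correct, len(labels)


# Hypothetical model outputs for a batch of 3 images and 3 classes
outputs = [
    [0.1, 2.3, -0.5],  # predicted class 1
    [1.7, 0.2, 0.9],   # predicted class 0
    [0.0, 0.4, 3.1],   # predicted class 2
]
labels = [1, 2, 2]

correct, total = batch_accuracy(outputs, labels)
accuracy = 100 * correct / total
print(f"Validation accuracy: {accuracy:.2f}%")
```

Accumulating correct and total across batches, as the training function does, gives the exact accuracy even when the last batch is smaller than the others.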

8.5.2 Train the model

Let’s train the model using the training and evaluation functions defined above.

# Move the model to the GPU if one is available
if torch.cuda.is_available():
    model = model.to('cuda')

model = train(model, criterion, optimizer, train_loader, val_loader, num_epochs=5)

8.6 Inference

To perform inference on new images, we will define a function that takes an image as input, preprocesses it, and passes it through the model to get the predicted class.

def predict_image(model, image_path):
    """
    Predict the class of a sample image.
    
    Args:
        model: The trained model.
        image_path (str): Path to the image to predict.
    
    Returns:
        int: Predicted class label (0-indexed; add 1 to get the RTS count).
    """
    transform = T.Compose([
        T.Resize((256, 256)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path).convert("RGB")
    # Apply the transformations and add a batch dimension
    image = transform(image).unsqueeze(0)
    device = next(model.parameters()).device
    image = image.to(device)

    model.eval() # Set the model to evaluation mode
    with torch.no_grad():
        outputs = model(image)
        _, predicted = torch.max(outputs, 1)
        return predicted.item()

img_name = "valtest_yg_070.jpg"
img_dir = "./cyber2a/rts/images"
img_path = os.path.join(img_dir, img_name)

predicted_class = predict_image(model, img_path)

print(f"Predicted class for {img_name}: {predicted_class}")

8.7 Visualization

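To visualize the model’s prediction, we will display a sample image from the validation set with its ground-truth RTS polygons overlaid, read from the COCO-format annotation file cyber2a/rts_coco.json, and show the predicted number of RTS in the title.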

import matplotlib.pyplot as plt
import cv2
import numpy as np

def display_image_with_annotations(image_name, image_folder):
    """
    Load an image and draw its annotation polygons on it.

    Parameters:
    - image_name: str, the name of the image file to display.
    - image_folder: str, the folder where the images are stored.

    Returns:
    - cv2_image: The annotated image.
    """
    # Load the COCO annotations from the JSON file
    with open("cyber2a/rts_coco.json", "r") as f:
        rts_coco = json.load(f)

    # Get the image ID for the image
    image_id = None
    for image in rts_coco["images"]:
        if image["file_name"] == image_name:
            image_id = image["id"]
            break

    if image_id is None:
        raise ValueError(f"Image {image_name} not found in COCO JSON file.")

    # Get the annotations for the image
    annotations = []
    for annotation in rts_coco["annotations"]:
        if annotation["image_id"] == image_id:
            annotations.append(annotation["segmentation"])

    # Read the image
    cv2_image = cv2.imread(f"{image_folder}/{image_name}")
    if cv2_image is None:
        raise FileNotFoundError(
            f"Image file {image_name} not found in folder {image_folder}."
        )

    # Overlay the polygons on top of the image
    for annotation in annotations:
        for polygon in annotation:
            # Reshape polygon to an appropriate format for cv2.polylines
            polygon = np.array(polygon, dtype=np.int32).reshape((-1, 2))
            cv2.polylines(
                cv2_image, [polygon], isClosed=True, color=(0, 255, 0), thickness=2
            )

    cv2_image = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)

    return cv2_image


image = display_image_with_annotations(img_name, img_dir)

fig, ax = plt.subplots()
ax.imshow(image)
ax.set_title(f'number of predicted RTS: {predicted_class + 1}')
plt.show()
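
COCO polygon segmentations store each polygon as a flat list [x1, y1, x2, y2, ...], which is why the function above reshapes each polygon to shape (-1, 2) before drawing it. The plain-Python sketch below shows the same conversion without NumPy. flat_to_points is a hypothetical helper and the coordinates are made up.

```python
def flat_to_points(flat_polygon):
    """Turn [x1, y1, x2, y2, ...] into [(x1, y1), (x2, y2), ...] with integer coordinates."""
    if len(flat_polygon) % 2 != 0:
        raise ValueError("A flat polygon needs an even number of values.")
    xs = flat_polygon[0::2]  # every other value, starting at x1
    ys = flat_polygon[1::2]  # every other value, starting at y1
    # int() truncates toward zero, like casting to np.int32
    return [(int(x), int(y)) for x, y in zip(xs, ys)]


# Hypothetical "segmentation" entry holding a single triangle
segmentation = [[10.0, 20.0, 30.5, 20.0, 30.5, 40.9]]
points = [flat_to_points(polygon) for polygon in segmentation]
print(points)
```

One annotation can hold several polygons, so the number of polygons drawn may exceed the number of annotations for an image.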