This hands-on lab session offers participants practical experience with PyTorch for building, training, and evaluating neural network models. Participants will work with a sample dataset, load a pre-trained model, and fine-tune it to enhance performance. The session will guide participants through the building blocks of a deep learning application, including data, models, loss functions, optimization algorithms, training and evaluation, inference and visualization.
Note
You can access the Jupyter notebook for this hands-on lab on Google Colab.
8.1 Data
8.1.1 Dataset overview
In this hands-on lab, we will use a sample RTS (Retrogressive Thaw Slumps) dataset from Dr. Yili Yang’s research. While the RTS dataset was originally used for semantic segmentation, we will repurpose it for a classification task. The goal is to classify the number of RTS present in each image, with counts ranging from 1 to 10, which will serve as the ground truth for our model.
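PyTorch's classification losses expect class indices that start at 0, so an RTS count of 1 to 10 has to be shifted to a label of 0 to 9 before training and shifted back when a prediction is reported. A minimal sketch of that mapping follows; the helper names `count_to_class` and `class_to_count` are hypothetical and are not part of the lab notebook.

```python
NUM_CLASSES = 10  # RTS counts range from 1 to 10 in this lab


def count_to_class(count: int) -> int:
    """Map an RTS count (1..10) to a 0-indexed class label (0..9)."""
    if not 1 <= count <= NUM_CLASSES:
        raise ValueError(f"RTS count must be in 1..{NUM_CLASSES}, got {count}")
    return count - 1


def class_to_count(label: int) -> int:
    """Map a 0-indexed class label (0..9) back to an RTS count (1..10)."""
    if not 0 <= label < NUM_CLASSES:
        raise ValueError(f"Class label must be in 0..{NUM_CLASSES - 1}, got {label}")
    return label + 1


print(count_to_class(1), count_to_class(10))  # lowest and highest class labels
print(class_to_count(count_to_class(7)))      # round trip returns the original count
```

The same shift appears later in the lab as `label - 1` when loading data and `predicted_class + 1` when reporting a prediction.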
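The dataset structure and required files for this hands-on lab are as follows (paths as they are used by the code in this chapter):
cyber2a/rts/images/: The image files.
cyber2a/data_split.json: Lists of image file names for the "train" and "valtest" splits.
cyber2a/rts_cls.json: A mapping from each image file name to the number of RTS in that image.
cyber2a/rts_coco.json: COCO-format annotations (segmentation polygons), used in the visualization step.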
Let’s visualize the dataset by displaying one image and its corresponding label:
Visualize the dataset
import os
import json

from PIL import Image
import matplotlib.pyplot as plt

# Define the directory where images are stored
img_dir = "cyber2a/rts/images/"

# Load the data split file to get lists of training and validation/test images
with open("cyber2a/data_split.json", "r") as f:
    data_split = json.load(f)

# Retrieve the list of training images
img_list = data_split["train"]

# Load the image labels, where each image name maps to the number of RTS in the image
with open("cyber2a/rts_cls.json", "r") as f:
    img_labels = json.load(f)

# Select the first image file name from the training list and get its corresponding label
img_name = img_list[0]
img_label = img_labels[img_name]

# Print the image file name and its corresponding number of RTS
print(f"Image Name: {img_name}, Number of RTS: {img_label}")

# Construct the full path to the image file
img_path = os.path.join(img_dir, img_name)

# Open the image and convert it to RGB format
image = Image.open(img_path).convert("RGB")

# Convert the label to 0-indexed format for classification tasks
label = int(img_label) - 1

# Display the image using matplotlib
fig, ax = plt.subplots()
ax.imshow(image)
ax.axis("off")  # Hide the axis
ax.set_title(f"Label: {label}")  # Set the title to the 0-indexed label

# Show the plot
plt.show()
8.1.4 Build a custom dataset
To build a custom dataset, we will create a PyTorch dataset class that loads the images and their corresponding labels. The dataset class will implement the following methods:
__init__: Initialize the dataset by loading the image filenames and labels.
__len__: Return the total number of images in the dataset.
__getitem__: Load an image and its corresponding label based on the index.
Build a custom dataset
import json
import os

from PIL import Image
from torch.utils.data import Dataset


class RTSDataset(Dataset):
    def __init__(self, split, transform=None):
        """
        Args:
            split (str): One of 'train' or 'valtest' to specify the dataset split.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        # Define the directory where images are stored
        self.img_dir = "cyber2a/rts/images/"

        # Load the list of images based on the split (train/valtest)
        with open("cyber2a/data_split.json") as f:
            data_split = json.load(f)

        if split == "train":
            self.img_list = data_split["train"]
        elif split == "valtest":
            self.img_list = data_split["valtest"]
        else:
            raise ValueError("Invalid split: choose either 'train' or 'valtest'")

        # Load the image labels
        with open("cyber2a/rts_cls.json") as f:
            self.img_labels = json.load(f)

        # Store the transform to be applied to images
        self.transform = transform

    def __len__(self):
        """Return the total number of images."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the image to retrieve.

        Returns:
            tuple: (image, label) where image is the (optionally transformed)
                image and label is the corresponding 0-indexed label.
        """
        # Retrieve the image name using the index
        img_name = self.img_list[idx]

        # Construct the full path to the image file
        img_path = os.path.join(self.img_dir, img_name)

        # Open the image and convert it to RGB format
        image = Image.open(img_path).convert("RGB")

        # Get the corresponding label and adjust it to be 0-indexed
        label = int(self.img_labels[img_name]) - 1

        # Apply the transform if specified
        if self.transform:
            image = self.transform(image)

        return image, label
8.1.5 Test the custom dataset
To test the custom dataset, we will create an instance of the RTSDataset class for the training split and display the first few images with their corresponding labels:
Test the custom dataset
def display_sample_images(dataset, num_images=3):
    """
    Display sample images from the dataset.

    Args:
        dataset (Dataset): The dataset to sample images from.
        num_images (int): Number of images to display.
    """
    data, label = dataset[0]

    if isinstance(data, dict):
        # Multi-modal sample: one row per modality, one column per image
        num_modalities = len(data)
        fig, axs = plt.subplots(num_modalities, num_images, figsize=(20, 5))
        for i in range(num_images):
            data, label = dataset[i]
            for j, modality in enumerate(data):
                axs[j, i].imshow(data[modality])
                if j == 0:
                    axs[j, i].set_title(f"label: {label}")
                else:
                    axs[j, i].set_title(f"modality: {modality}")
                axs[j, i].axis("off")
    else:
        # Single image per sample: one row of images
        fig, axs = plt.subplots(1, num_images, figsize=(20, 5))
        for i in range(num_images):
            data, label = dataset[i]
            axs[i].imshow(data)
            axs[i].set_title(f"Label: {label}")
            axs[i].axis("off")

    plt.show()


# Create the training dataset (no transform, so samples are PIL images)
train_dataset = RTSDataset("train")

# Display sample images from the training dataset
display_sample_images(train_dataset)
8.1.6 Define data transforms and data loaders
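To prepare the data for training, we will define data transforms that resize each image to 256 × 256, convert it to a PyTorch tensor, and normalize each channel as (x - mean) / std using the ImageNet channel means and standard deviations that the pre-trained model expects. We will also create data loaders to load the data in batches during training and validation.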
Define data transforms and data loaders
import torch
import torchvision.transforms as T

# Define the transform for the dataset
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create the training and validation datasets with transforms
train_dataset = RTSDataset("train", transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)

val_dataset = RTSDataset("valtest", transform=transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False)
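To see what a data loader does with a dataset, it can help to try one on synthetic data that needs no files. The sketch below uses a hypothetical `ToyDataset` (random tensors standing in for RTS images, not part of the lab) that implements the same three methods as a custom dataset, then checks the shape of each batch.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    """Hypothetical stand-in for an image dataset: random 'images' with labels 0..9."""

    def __init__(self, num_samples=10):
        # Random tensors shaped like transformed RGB images (C, H, W)
        self.images = torch.rand(num_samples, 3, 256, 256)
        self.labels = torch.arange(num_samples) % 10

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]


toy_loader = DataLoader(ToyDataset(), batch_size=4, shuffle=False)

# Each batch stacks samples along a new first dimension: (batch, C, H, W)
batch_shapes = [tuple(images.shape) for images, labels in toy_loader]
print(batch_shapes)     # the last batch is smaller because 10 is not a multiple of 4
print(len(toy_loader))  # number of batches per epoch
```

With 10 samples and a batch size of 4, the loader yields two full batches and one partial batch, so `len(toy_loader)` counts batches, not samples.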
8.2 Models
8.2.1 Load a pre-trained model
In this hands-on lab, we will use a pre-trained ResNet-18 model as the backbone for our classification task. We will load the pre-trained ResNet-18 model from the torchvision library and modify the final fully connected layer to output 10 classes corresponding to the number of RTS in the images.
Load a pre-trained model and modify the final layer
from torchvision import models  # https://pytorch.org/vision/stable/models.html
from torchvision.models.resnet import ResNet18_Weights

# Load the pretrained ResNet18 model
model = models.resnet18(weights=ResNet18_Weights.DEFAULT)

# Modify the final layer to match the number of classes
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 10)

# Print the model to observe the new `fc` layer
print(model)
8.3 Loss functions
For the classification task, we will use the cross-entropy loss function, which is commonly used for multi-class classification problems.
Define the loss function
import torch

# Define the loss function
criterion = torch.nn.CrossEntropyLoss()
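`torch.nn.CrossEntropyLoss` takes raw, unnormalized scores (logits) and integer class indices; it applies log-softmax internally and averages the negative log-probability of the correct class over the batch. The sketch below checks this on made-up logits by comparing the built-in loss with a manual computation.

```python
import torch
import torch.nn.functional as F

# Made-up logits for a batch of 2 samples and 3 classes
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 2])  # 0-indexed ground-truth classes

# Built-in loss: expects logits, not probabilities
loss_builtin = torch.nn.CrossEntropyLoss()(logits, labels)

# Manual version: log-softmax, pick the log-probability of the true class, average
log_probs = F.log_softmax(logits, dim=1)
loss_manual = -log_probs[torch.arange(len(labels)), labels].mean()

print(loss_builtin.item(), loss_manual.item())  # the two values should agree
```

This is why the model's final layer outputs logits directly and no softmax layer is added before the loss.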
8.4 Optimization algorithms
We will use the SGD (Stochastic Gradient Descent) optimizer with a learning rate of 0.001 and a momentum of 0.9 to train the model.
Define the optimizer
import torch.optim as optim

# Define the optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
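With momentum, PyTorch's SGD keeps a running velocity for each parameter: the velocity is updated as `momentum * velocity + grad`, and the parameter moves by `lr * velocity`. The sketch below reproduces two steps by hand on the toy function f(p) = p² and compares them with `torch.optim.SGD`; the learning rate of 0.1 is chosen only to make the numbers easy to follow and is not the value used in this lab.

```python
import torch

lr, momentum = 0.1, 0.9  # illustrative values only

# A single scalar parameter optimized by PyTorch
p = torch.nn.Parameter(torch.tensor(1.0))
opt = torch.optim.SGD([p], lr=lr, momentum=momentum)

# The same parameter tracked by hand
p_manual, velocity = 1.0, 0.0

for _ in range(2):
    # PyTorch step on f(p) = p^2
    opt.zero_grad()
    loss = p ** 2
    loss.backward()
    opt.step()

    # Manual step: the gradient of p^2 is 2p
    grad = 2 * p_manual
    velocity = momentum * velocity + grad
    p_manual = p_manual - lr * velocity

print(p.item(), p_manual)  # both follow 1.0 -> 0.8 -> 0.46
```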
8.5 Training and evaluation
8.5.1 Define the training and evaluation functions
The train function below handles both phases in each epoch. In the training phase, it iterates over the training dataset, computes the loss, backpropagates the gradients, and updates the model parameters. In the validation phase, it iterates over the validation dataset with gradients disabled and computes the accuracy of the model.
Define the training and evaluation functions
import torch
from tqdm import tqdm


def train(model, criterion, optimizer, train_loader, val_loader, num_epochs=5):
    """
    Train the model and report the validation accuracy after each epoch.

    Args:
        model: The model to train.
        criterion: The loss function.
        optimizer: The optimizer.
        train_loader: DataLoader for the training data.
        val_loader: DataLoader for the validation data.
        num_epochs (int): Number of epochs to train.

    Returns:
        model: The trained model.
    """
    # Get the model's device once, so it is defined for both phases
    device = next(model.parameters()).device

    for epoch in range(num_epochs):
        # Set the model to training mode
        model.train()
        running_loss = 0.0

        for inputs, labels in tqdm(train_loader):
            # Move data to the appropriate device
            inputs, labels = inputs.to(device), labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass to get model outputs
            outputs = model(inputs)

            # Compute the loss
            loss = criterion(outputs, labels)

            # Backward pass to compute gradients
            loss.backward()

            # Update model parameters
            optimizer.step()

            # Accumulate the running loss
            running_loss += loss.item()

        # Average loss per batch for this epoch
        epoch_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss}")

        # Validation phase: set the model to evaluation mode
        model.eval()
        correct = 0
        total = 0

        # Disable gradient computation for validation
        with torch.no_grad():
            for images, labels in val_loader:
                # Move validation data to the appropriate device
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)

                # Get the predicted class (index of the largest logit)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f"Validation accuracy: {100 * correct / total}%")

    return model
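The accuracy bookkeeping in the validation phase can be traced on a small, made-up batch. The sketch below uses invented logits for four samples and three classes, takes the index of the largest logit in each row as the prediction, and counts how many predictions match the labels.

```python
import torch

# Made-up model outputs (logits) for 4 samples and 3 classes
outputs = torch.tensor([[0.1, 2.0, 0.3],
                        [1.5, 0.2, 0.1],
                        [0.0, 0.1, 0.9],
                        [0.3, 0.2, 0.1]])
labels = torch.tensor([1, 0, 1, 0])

# Predicted class = index of the largest logit in each row
predicted = outputs.argmax(dim=1)

# Count matches and convert to a percentage
correct = (predicted == labels).sum().item()
accuracy = 100 * correct / labels.size(0)

print(predicted.tolist())  # [1, 0, 2, 0]
print(accuracy)            # 75.0 (the third sample is misclassified)
```

`outputs.argmax(dim=1)` returns the same indices as the second value of `torch.max(outputs, 1)`.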
8.5.2 Train the model
Let’s train the model for five epochs using the train function defined above.
Train the model
# Move the model to the GPU if one is available
if torch.cuda.is_available():
    model = model.to("cuda")

model = train(model, criterion, optimizer, train_loader, val_loader, num_epochs=5)
8.6 Inference
To perform inference on new images, we will define a function that takes an image as input, preprocesses it, and passes it through the model to get the predicted class.
Inference
def predict_image(model, image_path):
    """
    Predict the class of a sample image.

    Args:
        model: The trained model.
        image_path (str): Path to the image to predict.

    Returns:
        int: Predicted class label (0-indexed).
    """
    transform = T.Compose([
        T.Resize((256, 256)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    image = Image.open(image_path).convert("RGB")

    # Apply the transformations and add a batch dimension
    image = transform(image).unsqueeze(0)

    # Move the image to the same device as the model
    device = next(model.parameters()).device
    image = image.to(device)

    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():
        outputs = model(image)
        _, predicted = torch.max(outputs, 1)

    return predicted.item()


img_name = "valtest_yg_070.jpg"
img_dir = "./cyber2a/rts/images"
img_path = os.path.join(img_dir, img_name)

predicted_class = predict_image(model, img_path)
print(f"Predicted class for {img_name}: {predicted_class}")
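The prediction is a 0-indexed class, so it has to be shifted by one to recover an RTS count. If a rough confidence value is also wanted, one option is to pass the logits through a softmax. The sketch below illustrates both steps on made-up logits for four classes; it is an optional extension, not part of `predict_image`.

```python
import torch

# Made-up logits for a single image and 4 classes (batch dimension of 1)
logits = torch.tensor([[0.2, 1.7, 0.1, -0.5]])

# Softmax turns logits into probabilities that sum to 1
probs = torch.softmax(logits, dim=1)

# Highest probability and its class index
confidence, predicted = probs.max(dim=1)

# Shift the 0-indexed class back to an RTS count
predicted_count = predicted.item() + 1

print(predicted.item(), predicted_count)  # class 1 corresponds to a count of 2
print(round(confidence.item(), 3))        # probability assigned to that class
```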
8.7 Visualization
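To visualize the model’s predictions, we will display a sample image from the validation set with its ground-truth RTS polygons (read from the COCO annotation file) overlaid, and show the predicted number of RTS in the title.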
Visualization
import json

import matplotlib.pyplot as plt
import cv2
import numpy as np


def display_image_with_annotations(image_name, image_folder):
    """
    Load an image and overlay its annotations.

    Parameters:
    - image_name: str, the name of the image file to display.
    - image_folder: str, the folder where the images are stored.

    Returns:
    - cv2_image: The annotated image (RGB).
    """
    # Load the COCO annotations from the JSON file
    with open("cyber2a/rts_coco.json", "r") as f:
        rts_coco = json.load(f)

    # Get the image ID for the image
    image_id = None
    for image in rts_coco["images"]:
        if image["file_name"] == image_name:
            image_id = image["id"]
            break

    if image_id is None:
        raise ValueError(f"Image {image_name} not found in COCO JSON file.")

    # Get the annotations for the image
    annotations = []
    for annotation in rts_coco["annotations"]:
        if annotation["image_id"] == image_id:
            annotations.append(annotation["segmentation"])

    # Read the image
    cv2_image = cv2.imread(f"{image_folder}/{image_name}")
    if cv2_image is None:
        raise FileNotFoundError(
            f"Image file {image_name} not found in folder {image_folder}."
        )

    # Overlay the polygons on top of the image
    for annotation in annotations:
        for polygon in annotation:
            # Reshape polygon to an appropriate format for cv2.polylines
            polygon = np.array(polygon, dtype=np.int32).reshape((-1, 2))
            cv2.polylines(
                cv2_image, [polygon], isClosed=True, color=(0, 255, 0), thickness=2
            )

    # OpenCV loads images as BGR; convert to RGB for matplotlib
    cv2_image = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)
    return cv2_image


image = display_image_with_annotations(img_name, img_dir)

fig, ax = plt.subplots()
ax.imshow(image)
ax.set_title(f"number of predicted RTS: {predicted_class + 1}")
plt.show()
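The reshape step in the overlay loop relies on how COCO stores a polygon: as one flat list of alternating x and y coordinates. The sketch below shows that conversion on a made-up rectangle (the coordinates are invented for illustration), producing the N × 2 integer array that `cv2.polylines` expects.

```python
import numpy as np

# Made-up COCO-style polygon: [x1, y1, x2, y2, x3, y3, x4, y4]
flat_polygon = [10.0, 20.0, 30.5, 20.0, 30.5, 45.2, 10.0, 45.2]

# Cast to int32 (fractional parts are truncated) and pair up (x, y) values
points = np.array(flat_polygon, dtype=np.int32).reshape((-1, 2))

print(points.shape)     # (4, 2): four vertices, each with an x and a y
print(points.tolist())  # [[10, 20], [30, 20], [30, 45], [10, 45]]
```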