Image Segmentation Tutorial — Identifying Brain Tumors using PyTorch

Arham Khan
Aug 29, 2021 · 8 min read

In this piece, we explore what image segmentation is and how a model can be trained to segment images, and we show example code for training an image segmentation model using PyTorch.

Image of model brain
Photo by Robina Weermeijer on Unsplash

Overview

What is image segmentation?

One of the most common image processing problems is image segmentation — we want to identify and mark an area of our image. For example, we might want to identify where lanes and signs on a road are, what items on a shelf need to be restocked, or we may want to identify blockages in the heart and lungs.

All of these tasks involve localizing one or more objects in an image. Crucially, this is more complicated than simply classifying whether an object exists in an image: we want to determine where exactly in the image the object lies. Luckily, this is a job perfect for neural networks!

How does segmentation work?

The goal of image segmentation is to identify the specific area of the image where an object resides. We can either output a bounding box, described by four numbers (the top-left corner coordinates plus a width and a height), or we can perform pixel-wise segmentation, where we classify whether the object is contained in each pixel.

The bounding box approach is preferable if we want to reduce the size or focus of the image. We might do this if the segmentation is part of a larger pipeline of models: for example, we could feed crops of each segmented region to a classifier. Sometimes we also don't need a high-resolution segmentation; it might be enough to know roughly where an object is. This is referred to as Object Localization.
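As a rough illustration, a bounding box in (x, y, width, height) format can be used to crop a region out of an image for a downstream classifier. The box values below are made up for the example.

```python
import numpy as np

# A hypothetical bounding box in (x, y, width, height) format,
# where (x, y) is the top-left corner in pixel coordinates.
image = np.zeros((256, 256, 3), dtype=np.uint8)  # dummy RGB image
box = (40, 60, 128, 96)                          # x, y, w, h

x, y, w, h = box
crop = image[y:y + h, x:x + w]  # crop the region for a downstream model
print(crop.shape)               # (96, 128, 3)
```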

If we want to know exactly which pixels an object is contained within, then we use pixel-wise or Semantic Segmentation. This is particularly useful for applications that require a high degree of accuracy; we might use it in pedestrian detection, lane detection, or in analyzing MRI scans for tumors. In this case, we want to classify whether each pixel belongs to a certain class, so our output is a map the same size as the image, representing the probability that each pixel belongs to a given class. Here we will focus mainly on semantic segmentation.

Visual comparison between classification, localization, and segmentation
Source: https://manipulation.csail.mit.edu/segmentation.html

Notice the connection between classification and segmentation problems. Semantic segmentation problems are a special class of classification problems where we want to assign a classification to each pixel, instead of classifying the image as a whole.

Binary Segmentation

Input image of horse alongside the segmented mask
Example of Binary Segmentation. Source: https://www.researchgate.net/figure/Semantic-segmentation-examples-using-WASPnet_fig1_337771814

The output vector in segmentation is called an image mask since it can be overlaid onto the original image in order to highlight regions of importance. Let's consider an example.

Say we want to perform semantic segmentation on a 256x256x3 image with only one class. We want to segment cars against the background image. To accomplish this we would train a model which outputs a 256x256x1 mask, where each element corresponds to the probability that the corresponding pixel in the input image belongs to the car class.

For instance, if pixel [0,0] in the top left corner of the mask has a value of 0.8 this implies that pixel [0,0] in the top left corner of the image belongs to the car class with a probability of 0.8.

Since we only have two classes, car and background, this is referred to as binary segmentation. We use a sigmoid activation to ensure that all outputs are between zero and one. In order to make predictions, we establish a threshold value, such as 0.5, above which we conclude that the pixel belongs to the given class.
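As a small illustration of this thresholding step, here is a sketch in PyTorch; the logits tensor simply stands in for a model's raw output.

```python
import torch

# Suppose `logits` is the raw model output for one image: shape (1, 256, 256).
logits = torch.randn(1, 256, 256)

# Sigmoid squashes each pixel's logit into a probability in (0, 1).
probs = torch.sigmoid(logits)

# Thresholding at 0.5 gives the final binary mask: 1 = car, 0 = background.
mask = (probs > 0.5).float()
print(mask.shape, mask.unique())  # torch.Size([1, 256, 256]), values 0. and 1.
```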

Multiclass Segmentation

Example multiclass segmentation map showing street, cars, and pedestrians being segmented
Source: https://www.researchgate.net/figure/Example-of-generated-segmentation-masks-with-MobileNetV2-In-the-top-row-prediction-mask_fig3_337758902

Sometimes we want to segment multiple classes in a single image. Let's consider the situation where we want to segment pedestrians, the sidewalk, and the street simultaneously.

If our input image is 256x256x3, then we would output a 256x256x4 mask. The number of channels in the mask is determined by the number of classes: the first channel corresponds to the pedestrian class, the second to the sidewalk class, and the third to the street. A fourth channel represents the background class.

We then form a probability distribution over the classes by performing a softmax operation on the channels. For example, let's consider pixel [0,0] in the mask. Then in the mask, element [0,0,0] (the first row, first column, and first channel) represents the probability that pixel [0,0] in the input image belongs to the pedestrian class.

Image showing visually how the mask channels correspond to a probability distribution over each class in relation to each pixel

Considering all four channels, let elements [0,0,0], [0,0,1], [0,0,2], and [0,0,3] in the mask be [0.1, 0.6, 0.2, 0.1]. Since the second channel (sidewalk) has the highest probability, this distribution implies that pixel [0,0] in the input image belongs to the sidewalk class.
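A short PyTorch sketch of this softmax-and-argmax step, using random numbers in place of real model outputs:

```python
import torch

# Hypothetical raw output for one image with 4 classes:
# channel 0 = pedestrian, 1 = sidewalk, 2 = street, 3 = background.
logits = torch.randn(4, 256, 256)

# Softmax over the channel dimension turns each pixel's 4 scores
# into a probability distribution over the classes.
probs = torch.softmax(logits, dim=0)

# The predicted class for each pixel is the channel with the highest probability.
pred = probs.argmax(dim=0)        # shape (256, 256), values in {0, 1, 2, 3}

# For pixel [0, 0], probs[:, 0, 0] is the distribution described above,
# e.g. something like [0.1, 0.6, 0.2, 0.1] -> class 1 (sidewalk).
print(probs[:, 0, 0], pred[0, 0])
```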

Loss Functions

Because segmentation is a subset of classification, we often use the cross-entropy loss between our output mask and the target mask. This treats the two as probability distributions that should be as close to one another as possible.

Image of the cross-entropy loss formula: the negative sum, over classes, of the target probability times the log of the predicted probability
Cross-Entropy Loss. Source: https://androidkt.com/choose-cross-entropy-loss-function-in-keras/
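In PyTorch, the corresponding built-in losses are BCEWithLogitsLoss for the binary case and CrossEntropyLoss for the multiclass case. A small sketch with dummy tensors:

```python
import torch
import torch.nn as nn

# Binary segmentation: compare a (N, 1, H, W) logit map against a 0/1 target mask.
bce = nn.BCEWithLogitsLoss()          # combines sigmoid + binary cross-entropy
logits = torch.randn(2, 1, 256, 256)
target = torch.randint(0, 2, (2, 1, 256, 256)).float()
loss_binary = bce(logits, target)

# Multiclass segmentation: logits are (N, C, H, W), targets are class indices (N, H, W).
ce = nn.CrossEntropyLoss()            # applies log-softmax over the channel dimension
logits_mc = torch.randn(2, 4, 256, 256)
target_mc = torch.randint(0, 4, (2, 256, 256))
loss_multi = ce(logits_mc, target_mc)
```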

We can also use a loss based on the Intersection over Union (IoU) or the closely related Dice coefficient. These losses are sometimes preferred because they directly penalize the model for producing a mask that does not overlap well with the provided label.

Dice coefficient: 2 times the intersection of the masks divided by the sum of the two mask areas; the Dice loss is one minus this value.
Dice Loss. Source: https://devblogs.microsoft.com/cse/2018/07/05/satellite-images-segmentation-sustainable-farming/

Because the Dice loss is sometimes unstable, we may choose to combine it with cross-entropy in order to reap the benefits of smooth gradients from cross-entropy while also maintaining the objective of the Dice loss.
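One common way to combine the two is a weighted sum of binary cross-entropy and a soft Dice term. The sketch below is one possible formulation, not the exact loss used in the original notebook; the smoothing constant and the 0.5 weighting are illustrative choices.

```python
import torch
import torch.nn as nn

def dice_loss(logits, target, smooth=1.0):
    # Soft Dice loss: 1 - (2 * |A ∩ B| + smooth) / (|A| + |B| + smooth).
    probs = torch.sigmoid(logits)
    probs = probs.view(probs.size(0), -1)
    target = target.view(target.size(0), -1)
    intersection = (probs * target).sum(dim=1)
    dice = (2.0 * intersection + smooth) / (probs.sum(dim=1) + target.sum(dim=1) + smooth)
    return 1.0 - dice.mean()

class BCEDiceLoss(nn.Module):
    """Weighted sum of binary cross-entropy and the soft Dice loss."""
    def __init__(self, bce_weight=0.5):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.bce_weight = bce_weight

    def forward(self, logits, target):
        return (self.bce_weight * self.bce(logits, target)
                + (1.0 - self.bce_weight) * dice_loss(logits, target))
```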

Semantic Segmentation Tutorial

Overview

Here we will explore applying semantic segmentation to the Brain MRI Segmentation dataset available on Kaggle. This is a binary segmentation task where we are asked to identify the location of glioma present in brain MRIs obtained from The Cancer Imaging Archive.

Example images with corresponding tumor masks
Example Images and Labels from the dataset.

Some scans contain no tumors at all; this is part of the challenge the dataset poses. In practice, we would address this by constructing a pipeline that includes a classifier that first tells us whether or not an image contains a tumor; we would then feed only the images with tumors to our final segmentation model. A minimal sketch of this two-stage idea follows below.
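The sketch assumes a classifier that outputs a single tumor/no-tumor logit and a segmenter that outputs a per-pixel logit map; both names are placeholders rather than models defined in this article.

```python
import torch

def segment_with_gate(image, classifier, segmenter, threshold=0.5):
    # `image` is a single (1, 3, H, W) tensor; `classifier` returns one logit,
    # `segmenter` returns a (1, 1, H, W) logit map. All names are illustrative.
    with torch.no_grad():
        if torch.sigmoid(classifier(image)).item() > threshold:
            return (torch.sigmoid(segmenter(image)) > 0.5).float()
        # No tumor detected: return an all-background mask.
        return torch.zeros(1, 1, image.shape[2], image.shape[3])
```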

Since we are only concerned with the segmentation portion here, I will only show the methodology that we would use for segmenting images with tumors.

Reading Data

This step is relatively simple: we use our preferred utility to read in the images and place them in a PyTorch Dataset. The Dataset can then be served to our model through a PyTorch DataLoader, and we are free to use it in training and validation as we wish.
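A minimal sketch of such a Dataset, assuming we already have matched lists of image and mask file paths; the file layout and the transform signature are assumptions, not the exact code from the notebook.

```python
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF

class BrainMRIDataset(Dataset):
    """Sketch of a Dataset for (image, mask) pairs."""
    def __init__(self, image_paths, mask_paths, transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform  # hypothetical joint image/mask transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        mask = Image.open(self.mask_paths[idx]).convert("L")
        if self.transform is not None:
            image, mask = self.transform(image, mask)
        image = TF.to_tensor(image)                  # (3, H, W), floats in [0, 1]
        mask = (TF.to_tensor(mask) > 0.5).float()    # (1, H, W), binary mask
        return image, mask

# loader = DataLoader(BrainMRIDataset(image_paths, mask_paths), batch_size=16, shuffle=True)
```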

In order to improve performance and make good use of our data, it is often a good idea to augment our dataset. This gives us a more diverse dataset and more training examples, and it leads to a more robust model. Best of all, augmenting existing data costs us almost nothing compared to collecting new data.

For segmentation tasks, it is common to augment data using techniques such as Affine Transformations, Elastic Transformations, and Pixel-Level Transformations.

Affine Transformations include flips, rotations, translations, scaling, and shearing. Elastic Transformations deform the image while roughly preserving its overall geometry; this is particularly applicable in medicine, since elastic deformations mimic realistic deformations that tissue could undergo. Pixel-Level Transformations do not affect the geometry of the image at all; we use these to adjust intensity or color in order to improve robustness. A sketch of such an augmentation pipeline follows below.
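One way to build such a pipeline is with the albumentations library, which applies the same geometric transform to an image and its mask; the specific transforms and probabilities below are illustrative choices, not the exact configuration from the notebook.

```python
import albumentations as A

transform = A.Compose([
    A.HorizontalFlip(p=0.5),                          # affine: flip
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1,
                       rotate_limit=15, p=0.5),       # affine: translate/scale/rotate
    A.ElasticTransform(p=0.3),                        # elastic deformation
    A.RandomBrightnessContrast(p=0.3),                # pixel-level intensity change
])

# albumentations applies the same geometric transform to the image and its mask:
# augmented = transform(image=image_np, mask=mask_np)
# image_aug, mask_aug = augmented["image"], augmented["mask"]
```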

Model

Because it has already been shown to be effective for medical image segmentation, we will be using the UNet model. There is also an implementation of UNet++ in the Kaggle Kernel linked below.

Image of UNet architecture
UNet Architecture. Source: https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/

The basic idea of UNet is to propagate features from the early, high-resolution layers to the later layers so that the final output can combine insights from several receptive fields. Early layers view only a small portion of the image and capture fine local detail, while deeper layers capture features over a larger portion of the image and provide general context.

The model concatenates these high-resolution encoder feature maps onto the decoder's feature maps, so that fine-grained spatial information propagates into the deep layers of the network and provides detail for the final prediction. A compact sketch of the architecture follows below.
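To make the structure concrete, here is a pared-down UNet sketch with only two downsampling steps; the real architecture is deeper, and the channel sizes here are illustrative.

```python
import torch
import torch.nn as nn

def double_conv(in_ch, out_ch):
    # Two 3x3 convolutions with ReLU: the basic UNet building block.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
    )

class MiniUNet(nn.Module):
    """A pared-down UNet sketch: encoder, bottleneck, decoder with skip connections."""
    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()
        self.enc1 = double_conv(in_ch, 64)
        self.enc2 = double_conv(64, 128)
        self.bottleneck = double_conv(128, 256)
        self.pool = nn.MaxPool2d(2)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = double_conv(256, 128)     # 128 from up2 + 128 from enc2 skip
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = double_conv(128, 64)      # 64 from up1 + 64 from enc1 skip
        self.head = nn.Conv2d(64, out_ch, 1)  # 1x1 conv to per-pixel logits

    def forward(self, x):
        e1 = self.enc1(x)                     # high-resolution encoder features
        e2 = self.enc2(self.pool(e1))
        b = self.bottleneck(self.pool(e2))
        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))   # skip connection
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))  # skip connection
        return self.head(d1)                  # (N, out_ch, H, W) logits
```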

Training

Here we have a standard training loop. In each epoch, we update the network parameters on the training batches and then record the validation loss; a sketch is shown below.
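A sketch of that loop, with all names (model, loaders, criterion, optimizer) assumed to be defined elsewhere:

```python
import torch

def train(model, train_loader, val_loader, criterion, optimizer, device, epochs=20):
    for epoch in range(epochs):
        # Update parameters on the training set.
        model.train()
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), masks)
            loss.backward()
            optimizer.step()

        # Record the validation loss for this epoch.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                val_loss += criterion(model(images), masks).item()
        print(f"epoch {epoch + 1}: val loss {val_loss / len(val_loader):.4f}")
```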

Results

Below are some results of the experiments. These are generated using UNet on the augmented dataset alongside the Dice Loss. The results of various other experiments are available in the original Kaggle notebook.

In every experiment, the model tends to segment larger objects more easily but struggles on samples where the label is much smaller (see row six). This could be due to underrepresentation in the data or to the inherent difficulty of detecting very small objects. It could be addressed via further data augmentation, tuning the weighting of the Dice and cross-entropy losses, or adding a loss penalty for incorrectly classifying background pixels.

Columns of images showing input from dataset and predictions from model
Results with UNet and Dice Loss. Pictured are the input images (left), the ground truth labels (center), and the network predictions (right).

Conclusion

Here we learned how image segmentation works and how to train a model in PyTorch to segment images. The original Kaggle Kernel contains more experiments, data processing code, and implementations of UNet and UNet++. Please leave some feedback, clap, and follow for more!
