07 Torchvision Data Augmentation for More Diverse Data #

Hello, I’m Fang Yuan.

In the previous lesson, we took the first step in training - data loading. We got a preliminary understanding of Torchvision and learned how to use Torchvision to load data. However, it is not enough to just load the images from the dataset. During the training process, the neural network model expects the data in the form of Tensors, not PIL objects. Therefore, we need to preprocess the data, such as converting the image format.

At the same time, the loaded image data may need to undergo a series of image transformations and augmentations, such as cropping borders, adjusting image proportions and sizes, normalization, etc., to help the model better learn the features of the data. All these operations can be done using the torchvision.transforms tool.

Today, we will learn how to use Torchvision to perform data preprocessing and image transformation and augmentation.

torchvision.transforms - Image Processing Tool #

The torchvision.transforms package in the Torchvision library provides common image operations for both Tensor and PIL Image objects, including random cropping, rotation, and data type conversions.

Based on their functionality, the operations in torchvision.transforms fall roughly into the following categories: data type conversion, transformations on PIL.Image and Tensor objects, and combinations of transformations. Let’s learn about these operations one by one.

Data Type Conversion #

In the previous lesson, we learned how to read images from a dataset, and the data we obtained was in the form of PIL Image objects. However, during the model training phase, we need to pass in data in the form of Tensor in order for the neural network to perform computations.

So how do we convert data in PIL.Image or numpy.ndarray format to Tensor format? This requires the transforms.ToTensor() class. For typical 8-bit images, ToTensor also rearranges the data from (H, W, C) to (C, H, W) and scales the pixel values from [0, 255] to floating-point values in [0.0, 1.0].

Conversely, to convert data in Tensor or numpy.ndarray format to PIL.Image format, we use the transforms.ToPILImage(mode=None) class. This is the inverse operation of ToTensor.

The parameter mode represents the mode of the PIL.Image. If mode is None (default value), it will be inferred based on the dimensions of the input data:

  • If the input has 3 channels, mode will be ‘RGB’.
  • If the input has 4 channels, mode will be ‘RGBA’.
  • If the input has 2 channels, mode will be ‘LA’.
  • If the input has 1 channel, the mode will be determined based on the type of the input data.

Image

Now that we’ve covered the usage, let’s look at a specific example to deepen our understanding. Taking the logo image of Geek Time (with the file name ‘jk.jpg’) as an example, we will perform the conversion between different data types. The specific code is as follows:

from PIL import Image
from torchvision import transforms

img = Image.open('jk.jpg')
display(img)  # display() is provided by Jupyter/IPython notebooks
print(type(img)) # The type of img is PIL.JpegImagePlugin.JpegImageFile

# Convert PIL.Image to Tensor
img1 = transforms.ToTensor()(img)
print(type(img1)) # The type of img1 is torch.Tensor

# Convert Tensor to PIL.Image
img2 = transforms.ToPILImage()(img1)
print(type(img2)) # The type of img2 is PIL.Image.Image

First, we read the image and check its type, which is PIL.JpegImagePlugin.JpegImageFile. It’s important to note that the class PIL.JpegImagePlugin.JpegImageFile is a subclass of PIL.Image.Image. Then, we use transforms.ToTensor() to convert the PIL.Image to Tensor. Finally, we convert the Tensor back to PIL.Image using transforms.ToPILImage().

Transformations on PIL.Image and Tensor #

torchvision.transforms provides a variety of image transformation methods, such as resizing, cropping, flipping, etc. These image transformation operations can accept multiple data formats, not only directly transforming PIL format images but also transforming Tensors without the need for additional data type conversion.

Let’s look at them one by one.

Resize #

Resizes the input PIL Image or Tensor to the given size. The definition is as follows:

torchvision.transforms.Resize(size, interpolation=2)

Let’s look at the relevant parameters:

  • size: The desired output size. If size is a tuple like (h, w), the output size of the image will be matched with it. If size is an int, the smaller edge of the image will be matched to this integer, and the other edge will be scaled proportionally.
  • interpolation: The interpolation algorithm. In the integer form shown above, the default value 2 represents PIL.Image.BILINEAR. Recent versions of Torchvision express the same default as InterpolationMode.BILINEAR and treat the integer form as deprecated.

Pay attention to whether size is a tuple or an integer, because the two behave differently.

Let me explain further: during training, we usually resize images to a fixed size, such as 128x128 or 256x256. If you specify the height and width directly, the output has exactly that size. But if size is a single integer, only the shorter edge is set to that value; the longer edge is scaled proportionally, so the result is generally not square.

After resizing, we usually perform a crop operation to a specified size. For images with similar height and width, this is not a problem. However, if there is a significant difference in height and width, it may crop out a lot of useful information. We will encounter this issue in more detail in the image classification section later.
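The effect described above can be estimated with a few lines of arithmetic. This is a pure-Python sketch of the size calculation, not Torchvision’s own code: the helper names resize_output_size and center_crop_kept_fraction are hypothetical, and the rounding may differ slightly from the library.

```python
def resize_output_size(width, height, size):
    """Approximate output (width, height) when Resize gets an int `size`.

    The shorter edge becomes `size`; the longer edge keeps the aspect ratio.
    Rounding uses int() here, which may differ slightly from the library.
    """
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_kept_fraction(width, height, crop):
    """Fraction of the image area that a square center crop keeps."""
    return (min(crop, width) * min(crop, height)) / (width * height)


# One nearly square image and one very wide image, both resized with size=256.
for w, h in [(300, 280), (1000, 200)]:
    new_w, new_h = resize_output_size(w, h, 256)
    kept = center_crop_kept_fraction(new_w, new_h, 256)
    print(f"{w}x{h} -> {new_w}x{new_h}, a 256x256 crop keeps {kept:.0%}")
```

For the nearly square image the crop keeps most of the content, while for the wide image it keeps only about a fifth of it.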

Let’s take the logo image of Geek Time as an example to see the effect of the resize operation:

from PIL import Image
from torchvision import transforms

# Define the resize operation
resize_img_oper = transforms.Resize((200,200), interpolation=2)

# Original image
orig_img = Image.open('jk.jpg')
display(orig_img)

# Resized image
img = resize_img_oper(orig_img)
display(img)

First, we define a resize operation and set the desired size to (200, 200). Then we apply the Resize transformation to the image of the Geek Time logo. The original image and the resized image are shown below:

Image

Crop #

torchvision.transforms provides various crop methods, such as center crop, random crop, five-crop, etc. Let’s take a look at their definitions one by one.

First, center crop, as the name suggests, crops the specified PIL Image or Tensor at the center. Its definition is as follows:

torchvision.transforms.CenterCrop(size)

Here, size represents the expected output cropping size. If size is a tuple like (h, w), the cropped image size will match it exactly. If size is an int, the cropped image will be a square of size (size, size).

Next, random crop randomly crops the specified PIL Image or Tensor at a random position. Its definition is as follows:

torchvision.transforms.RandomCrop(size, padding=None)

Here, size represents the expected output cropping size, as above. padding represents the optional padding on each border of the image. The default value is None, which means no padding. In my experience, the padding parameter is rarely needed.

Lastly, let’s talk about FiveCrop. It takes the given PIL Image or Tensor and crops it into five pieces from the four corners and the center. Its definition is as follows:

torchvision.transforms.FiveCrop(size)

size can be an int or a tuple, same as before.
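To make the crop regions concrete, the sketch below computes the boxes that a center crop and a five-crop would cut out of an image. It is plain arithmetic written for illustration: the function names and the 100x80 image size are hypothetical, and boxes are given as (left, top, right, bottom).

```python
def center_crop_box(img_w, img_h, crop_h, crop_w):
    """Return (left, top, right, bottom) of a centered crop."""
    left = round((img_w - crop_w) / 2)
    top = round((img_h - crop_h) / 2)
    return (left, top, left + crop_w, top + crop_h)


def five_crop_boxes(img_w, img_h, crop_h, crop_w):
    """Boxes for the four corners plus the center, in that order."""
    return [
        (0, 0, crop_w, crop_h),                          # top-left
        (img_w - crop_w, 0, img_w, crop_h),              # top-right
        (0, img_h - crop_h, crop_w, img_h),              # bottom-left
        (img_w - crop_w, img_h - crop_h, img_w, img_h),  # bottom-right
        center_crop_box(img_w, img_h, crop_h, crop_w),   # center
    ]


# Hypothetical 100x80 (width x height) image, cropped to (h, w) = (60, 70).
boxes = five_crop_boxes(100, 80, 60, 70)
for box in boxes:
    print(box)
```

Keep in mind that FiveCrop returns a tuple of five images rather than a single image, so any step that follows it has to handle all five.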

Now that we have mastered the definitions and parameter usage of various crop operations, let’s see how to call these crop operations with the following code:

from PIL import Image
from torchvision import transforms

# Define crop operations
center_crop_oper = transforms.CenterCrop((60,70))
random_crop_oper = transforms.RandomCrop((80,80))
five_crop_oper = transforms.FiveCrop((60,70))

# Original image
orig_img = Image.open('jk.jpg')
display(orig_img)

# Center crop
img1 = center_crop_oper(orig_img)
display(img1)

# Random crop
img2 = random_crop_oper(orig_img)
display(img2)

# Crop at corners and center
imgs = five_crop_oper(orig_img)
for img in imgs:
    display(img)

The process is similar to the Resize operation. First, we define the cropping operations, and then we apply them to the GeekTime logo image. The specific cropping effects are shown in the table below.

Image

Flipping #

Next, let’s take a look at flipping operations. torchvision.transforms provides two flipping operations: randomly flipping the image horizontally with a certain probability, and randomly flipping the image vertically with a certain probability. Let’s look at their definitions respectively.

Randomly flip the image horizontally with a probability of p, defined as:

torchvision.transforms.RandomHorizontalFlip(p=0.5)

Randomly flip the image vertically with a probability of p, defined as:

torchvision.transforms.RandomVerticalFlip(p=0.5)

Here, p represents the probability of flipping, which is set to 0.5 by default. Random flipping is convenient for data augmentation. If you want to ensure that the flipping operation is always performed, you can set p to 1.
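The role of p can be sketched in a few lines of plain Python. This is an illustration of the idea only, with nested lists standing in for an image; the helper names are hypothetical and this is not how Torchvision implements flipping internally.

```python
import random


def hflip(rows):
    """Mirror left-right: reverse every row."""
    return [row[::-1] for row in rows]


def vflip(rows):
    """Mirror top-bottom: reverse the order of the rows."""
    return rows[::-1]


def random_flip(rows, flip, p=0.5):
    """Apply `flip` with probability p, otherwise return the input unchanged."""
    return flip(rows) if random.random() < p else rows


img = [[1, 2, 3],
       [4, 5, 6]]

print(random_flip(img, hflip, p=1))  # always flipped horizontally
print(random_flip(img, vflip, p=1))  # always flipped vertically
print(random_flip(img, hflip, p=0))  # never flipped
```

With the default p=0.5, roughly half of the calls return the flipped image, which is what makes the operation useful for augmentation.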

Taking the GeekTime logo image as an example, the code for flipping the image is as follows.

from PIL import Image
from torchvision import transforms 

# Define flipping operations
h_flip_oper = transforms.RandomHorizontalFlip(p=1)
v_flip_oper = transforms.RandomVerticalFlip(p=1)

# Original image
orig_img = Image.open('jk.jpg') 
display(orig_img)

# Horizontal flipping
img1 = h_flip_oper(orig_img)
display(img1)

# Vertical flipping
img2 = v_flip_oper(orig_img)
display(img2)

The flipping effects are shown in the table below.

Image

Transformation on Tensors only #

In the current version of Torchvision (v0.10.0), various image transformation operations are now supported for both PIL Image and Tensor types at the same time. Therefore, there are very few transformation operations specifically for Tensor, only four of them: LinearTransformation, Normalize, RandomErasing, and ConvertImageDtype.

Here we will focus on the most commonly used operation: normalization. You can refer to the official documentation for the other three operations.

Normalization #

Normalization refers to subtracting the mean from each data point and then dividing the result by the standard deviation. The formula is as follows:

\[output=(input-mean)/std\]

To normalize an image, we normalize each channel using that channel’s mean and standard deviation. This keeps the distributions of the images in the dataset similar, which helps training converge faster and tends to improve the results.

Let me explain. First of all, normalization is common practice: a model trained on normalized data is more likely to perform well than one trained on unnormalized data.

Consider an example. After subtracting 50 from every pixel of the GeekTime logo, we get the image below.

Image

For us humans, this is still recognizably the GeekTime logo. A computer (that is, a convolutional neural network) may not recognize it, because the network extracts features from the pixel values of the image, and the pixel values of the two images differ. Why should the network treat them as the same image?

Normalized data avoids this problem. After normalization, the data is mapped to the same range, so even if some pixel values of images in the same category differ, the distributions of those pixel values are similar.
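The argument above can be checked with a short derivation. It is only a sketch: it assumes that the mean and standard deviation are computed from the very pixels being normalized (an assumption made here, not something the lesson states), and it ignores clipping at 0. Suppose every pixel is shifted by a constant c, for example c = 50:

\[input' = input - c, \quad mean' = mean - c, \quad std' = std\]

\[output' = (input' - mean')/std' = ((input - c) - (mean - c))/std = (input - mean)/std = output\]

Under that assumption, the shifted image and the original image normalize to identical values. When fixed constants are used for mean and std instead, the shift is not cancelled exactly, but all inputs are still brought onto a comparable scale.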

torchvision.transforms provides a class for normalizing tensors, which is defined as follows.

torchvision.transforms.Normalize(mean, std, inplace=False)

Where each parameter has the following meanings:

  • mean: The mean value of each channel.
  • std: The standard deviation of each channel.
  • inplace: Whether to perform the operation in place. The default value is False.

Taking the GeekTime logo image as an example, let’s see what happens when we normalize the image with mean and standard deviation of (0.5, 0.5, 0.5).

from PIL import Image
from torchvision import transforms 

# Define normalization operation
norm_oper = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Original image
orig_img = Image.open('jk.jpg') 
display(orig_img)

# Convert image to Tensor
img_tensor = transforms.ToTensor()(orig_img)

# Normalize
tensor_norm = norm_oper(img_tensor)

# Convert Tensor to image
img_norm = transforms.ToPILImage()(tensor_norm)
display(img_norm)

The process in the above code is as follows: First, we define the normalization operation with mean and standard deviation of (0.5, 0.5, 0.5). Then we convert the original image to a Tensor, perform normalization on the Tensor, and finally convert the Tensor back to an image for display.

The normalization effect is shown in the table below.

Image
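To see what the numbers look like, here is the same calculation done by hand on a few values. It is a pure-Python sketch of the formula output = (input - mean) / std applied per channel; the function name normalize and the sample pixel values are made up for the illustration.

```python
def normalize(channels, mean, std):
    """Apply output = (input - mean) / std to each channel separately."""
    return [
        [(value - m) / s for value in channel]
        for channel, m, s in zip(channels, mean, std)
    ]


# Three channels, each holding a few pixel values in the [0, 1] range
# that ToTensor produces.
pixels = [
    [0.0, 0.5, 1.0],    # R
    [0.25, 0.5, 0.75],  # G
    [1.0, 1.0, 0.0],    # B
]

result = normalize(pixels, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
for channel in result:
    print(channel)
```

With a mean and standard deviation of 0.5, values in [0, 1] are mapped to [-1, 1]: 0 becomes -1, 0.5 becomes 0, and 1 stays 1.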

Combining Transformations #

In fact, all the operations mentioned earlier can be combined using the Compose class to perform consecutive operations.

The Compose class combines multiple transformations together, and its definition is as follows.

torchvision.transforms.Compose(transforms)

Where transforms is a list of Transform objects, representing the transformations to be combined. Let’s try it out with an example. If we want to resize the image to 200x200 pixels and then randomly crop it into an 80x80 square, we can combine the Resize and RandomCrop transformations. The specific code is shown below.

from PIL import Image
from torchvision import transforms 

# Original image
orig_img = Image.open('jk.jpg') 
display(orig_img)

# Define composed operation
composed = transforms.Compose([transforms.Resize((200, 200)),
                               transforms.RandomCrop(80)])

# Image after composed operation
img = composed(orig_img)
display(img)

The result of running the code is shown in the table below, and I recommend you to try it out yourself.

Image
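Conceptually, Compose simply calls each transformation in turn and passes the result along. The class below is a hypothetical, simplified stand-in written for illustration (MiniCompose is not part of Torchvision), with two toy functions in place of real transforms.

```python
class MiniCompose:
    """Hypothetical, simplified stand-in for transforms.Compose."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        # Feed the output of each step into the next one, in list order.
        for t in self.transforms:
            x = t(x)
        return x


def add_one(x):
    return x + 1


def double(x):
    return x * 2


first = MiniCompose([add_one, double])(3)   # (3 + 1) * 2
second = MiniCompose([double, add_one])(3)  # (3 * 2) + 1
print(first, second)
```

The two results differ, which shows that the order of the list matters. For the same reason, when the input is a PIL image, a Tensor-only operation such as Normalize has to come after ToTensor.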

Using Compose with Datasets #

We will use the Compose class frequently in actual projects. When combined with the torchvision.datasets package, it lets us perform image transformations and data augmentation operations while reading the dataset. Let’s take a look together.

In the previous lesson, do you remember the parameter “transform” when we used torchvision.datasets to load the MNIST dataset? It is used for preprocessing operations on images, such as data augmentation, normalization, rotation, or scaling. The “transform” parameter can accept a torchvision.transforms operation or a combination of operations defined by the Compose class.

In the previous lesson, when we loaded the MNIST dataset, the image data obtained directly was of type PIL.Image.Image. However, in cases where we need to train models for tasks like handwriting digit recognition, the model expects data of type Tensor, not PIL objects. In this case, we can use the “transform” parameter to perform type conversion on the data while it is being read, so that the data obtained can be directly of type Tensor.

Besides converting the data type, we can also add preprocessing and data augmentation operations, such as normalization or random flipping. We just need to use the Compose class mentioned above to combine them. In this way, a series of preprocessing and augmentation operations is completed while the data is being read.
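How a dataset uses the transform parameter can be sketched with a toy class. ToyDataset below is hypothetical and greatly simplified; it only illustrates the general pattern of applying the transform each time an item is read, with plain lists standing in for images.

```python
class ToyDataset:
    """Hypothetical dataset showing where `transform` is applied."""

    def __init__(self, samples, transform=None, target_transform=None):
        self.samples = samples  # list of (image, label) pairs
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        image, label = self.samples[index]
        # Transforms run lazily, each time an item is read.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label


def scale(pixels):
    """Toy transform: map 0..255 values into 0..1."""
    return [p / 255 for p in pixels]


raw = [([0, 255], 7), ([51, 102], 2)]
dataset = ToyDataset(raw, transform=scale)

print(dataset[0])
print(raw[0])  # the stored sample itself is untouched
```

Because the transform runs on every read, a random transformation can return a different augmented version of the same image each time it is requested.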

Let’s take the example of loading the MNIST dataset to see how we can perform data preprocessing and other operations while reading the data. The specific code is as follows.

from torchvision import transforms
from torchvision import datasets

# Define a transform
my_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # one-element tuples: MNIST images have a single channel
])

# Load the MNIST dataset with data transformation
mnist_dataset = datasets.MNIST(root='./data',
                               train=False,
                               transform=my_transform,
                               target_transform=None,
                               download=True)

# Check the data type after transformation
image, label = mnist_dataset[0]
print(type(image))
# Output:
# <class 'torch.Tensor'>

Of course, the MNIST dataset is very simple, and it works very well even without any processing. But it is indeed suitable for learning and can be used for various experiments.

Next, let’s take a look at the transforms used in the image classification project to get a sense of what actual transforms look like:

transform = transforms.Compose([
    transforms.RandomResizedCrop(dest_image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

This is the transform I use in my project. There are many methods for data augmentation, but in my experience, using more does not necessarily result in better performance.

Conclusion #

Congratulations on completing this lesson. Let me summarize it for you.

The main focus of today’s lesson was on the use of torchvision.transforms tools. This includes common image processing operations and how to combine them with torchvision.datasets.

Common image processing operations include data type conversion, image size changes, cropping, flipping, normalization, and more. The Compose class chains multiple transformation operations, passed in as a list, into a single operation that applies them in order.

By combining torchvision.transforms with torchvision.datasets, you can perform a series of image transformations and data augmentation operations while loading the data. This not only allows you to feed the data directly into your model for training, but also speeds up the convergence of the model, enabling it to better learn the features of the data.

Of course, in real-world projects, we will have our own data instead of using the publicly available datasets provided by torchvision.datasets. However, the torchvision.transforms we discussed today can still be used in our custom datasets. I will continue to explain this in more detail in the image classification practice.

In the next lesson, we will introduce other interesting features in Torchvision, including instantiating classic network models and other useful functions.

Practice for each lesson #

What is the function of the transforms module in Torchvision?

Feel free to communicate and discuss with me in the comments section. You are also encouraged to share this lesson with your friends and try out various functions of Torchvision together.