06 Torchvision the First Step of Data Reading and Training

Hello, I’m Fang Yuan.

Today, we begin our study of model training. If we think of a model as a car, then developing it is like running a complete production line, in which every step is connected to the others and none can be skipped. These steps include data loading, network design, choosing an optimization method and a loss function, and some auxiliary tools. Later on, you may build your own luxury car, or stand on the shoulders of predecessors and improve their work.

Imagine trying to do this without a clear grasp of the methods used in these basic steps: could you proceed smoothly? That is why the goal of this module is to lay a solid foundation. By studying it, you will gain a clear understanding of the rich APIs that PyTorch provides.

Torchvision is a Python package that works alongside PyTorch and contains many image processing tools. We will start with data processing as our first step in learning PyTorch. In this lesson, we first introduce the datasets commonly used in Torchvision and how to load them. In the next two lessons, I will guide you through common image processing methods and other interesting features of Torchvision.

Data Loading in PyTorch #

The first step in training is data loading. PyTorch provides a convenient mechanism for this: combining the Dataset class with the DataLoader class yields a data iterator. During training or prediction, this iterator outputs the data for each batch and can apply the corresponding data preprocessing and data augmentation operations.

Now let’s take a look at the Dataset class and the DataLoader class.

Dataset Class #

The Dataset class in PyTorch is an abstract class that represents a dataset. By inheriting from Dataset, we can customize the format, size, and other properties of our dataset, and the resulting dataset can then be used directly by the DataLoader class.

In fact, whether a dataset is custom or officially packaged, it is essentially a class that inherits from Dataset. When inheriting from Dataset, we need to override at least the following methods:

  • __init__(): the constructor, where you can customize how data is loaded and perform data preprocessing;
  • __len__(): returns the size of the dataset;
  • __getitem__(): retrieves a specific data item from the dataset by index.

The principles alone may be hard to grasp, so let’s write a simple example that uses the Dataset class to define a dataset of Tensors.

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    # Constructor
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
    # Returns the size of the dataset
    def __len__(self):
        return self.data_tensor.size(0)
    # Returns the indexed data and label
    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

In the code above, we define a dataset named MyDataset. In the constructor, we pass in the data and labels as Tensors; in __len__, we simply return the size of the first dimension of the data Tensor; in __getitem__, we return the data and label at the given index.

Now let’s see how to use the dataset we just defined. First, randomly generate a 10 x 3 data Tensor (10 samples with 3 features each), and then generate a label Tensor of length 10 to go with it. Use these two Tensors to create a MyDataset object. We can then call len() directly to check the size of the dataset, and use indexing to retrieve individual samples.

# Generate data
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,)) # Labels are 0 or 1

# Wrap the data into a Dataset
my_dataset = MyDataset(data_tensor, target_tensor)

# Check the size of the dataset
print('Dataset size:', len(my_dataset))
'''
Output:
Dataset size: 10
'''

# Access data using indexing
print('tensor_data[0]: ', my_dataset[0])
'''
Output:
tensor_data[0]:  (tensor([ 0.4931, -0.0697,  0.4171]), tensor(0))
'''
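MyDataset does the same job as PyTorch’s built-in torch.utils.data.TensorDataset, which wraps tensors that share the same first dimension and returns one tuple per index. The minimal sketch below repeats MyDataset so it runs on its own and checks that the two agree; in practice, TensorDataset can be used directly for in-memory tensors.

```python
import torch
from torch.utils.data import Dataset, TensorDataset


class MyDataset(Dataset):
    """Same class as in the lesson, repeated so this block is self-contained."""

    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __len__(self):
        return self.data_tensor.size(0)

    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]


torch.manual_seed(0)  # fixed seed so the comparison is repeatable
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,))

custom = MyDataset(data_tensor, target_tensor)
builtin = TensorDataset(data_tensor, target_tensor)

# Both report the same size and return the same (data, label) pair per index.
print(len(custom), len(builtin))  # 10 10
for i in range(len(custom)):
    x1, y1 = custom[i]
    x2, y2 = builtin[i]
    assert torch.equal(x1, x2) and torch.equal(y1, y2)
```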

DataLoader Class #

In real-world projects, the amount of data is often large. Given limited memory and I/O speed, it is usually impractical to load all the data into memory at once during training, and loading it with a single process can be too slow. We therefore need to load the data iteratively, using multiple processes. DataLoader is designed for exactly these needs.
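This is possible because a Dataset only has to produce a sample when __getitem__ is called; it does not need to keep every sample in memory. As a minimal sketch, the hypothetical SquaresDataset below (not a PyTorch class) generates each sample on demand; a real dataset would typically read a single file from disk at that point instead.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical dataset that computes samples lazily instead of storing them."""

    def __init__(self, size):
        self.size = size  # only the length is stored, not the samples

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        if not 0 <= index < self.size:
            raise IndexError(index)
        # Each sample is created only when requested, e.g. by a DataLoader.
        x = torch.tensor([float(index)])
        y = torch.tensor([float(index) ** 2])
        return x, y


dataset = SquaresDataset(1_000_000)
print(len(dataset))  # 1000000
print(dataset[3])    # (tensor([3.]), tensor([9.]))
```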

DataLoader is an iterable. In its most basic usage, you pass in a Dataset object, and it yields the data in batches of batch_size samples, so only one batch needs to be assembled at a time, which saves memory. It also supports multi-process loading, data shuffling, and other processing.

The calling syntax of the DataLoader class is as follows:

from torch.utils.data import DataLoader
tensor_dataloader = DataLoader(dataset=my_dataset, # Input dataset, required parameter
                               batch_size=2,       # Batch size for output
                               shuffle=True,       # Whether to shuffle the data
                               num_workers=0)      # Number of processes, 0 means only the main process

# Output in a loop
for data, target in tensor_dataloader: 
    print(data, target)
'''
Output:
tensor([[-0.1781, -1.1019, -0.1507],
        [-0.6170,  0.2366,  0.1006]]) tensor([0, 0])
tensor([[ 0.9451, -0.4923, -1.8178],
        [-0.4046, -0.5436, -1.7911]]) tensor([0, 0])
tensor([[-0.4561, -1.2480, -0.3051],
        [-0.9738,  0.9465,  0.4812]]) tensor([1, 0])
tensor([[ 0.0260,  1.5276,  0.1687],
        [ 1.3692, -0.0170, -1.6831]]) tensor([1, 0])
tensor([[ 0.0515, -0.8892, -0.1699],
        [ 0.4931, -0.0697,  0.4171]]) tensor([1, 0])
'''
 
# Output one batch (use the built-in next(); recent PyTorch iterators have no .next() method)
print('One batch tensor data: ', next(iter(tensor_dataloader)))
'''
Output:
One batch tensor data:  [tensor([[ 0.9451, -0.4923, -1.8178],
        [-0.4046, -0.5436, -1.7911]]), tensor([0, 0])]
'''

Based on the code above, here is a summary of the DataLoader parameters it uses:

  • dataset: Dataset type, the input dataset, a required parameter;
  • batch_size: int type, the number of samples in each batch;
  • shuffle: bool type, whether to shuffle the data at the beginning of each epoch;
  • num_workers: int type, the number of subprocesses used to load data; 0 (the default) means the data is loaded in the main process.
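A few further details on these parameters, shown in a minimal sketch that uses a 10-sample TensorDataset as a stand-in for my_dataset. First, len() on a DataLoader returns the number of batches per epoch; when the dataset size is not a multiple of batch_size, the last batch is smaller, and the drop_last parameter (not used in the example above) discards it. Second, because shuffle=True draws a new order for every new iterator, results differ between runs; passing a seeded torch.Generator through the generator parameter makes the order repeatable.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for my_dataset: 10 samples with 3 features and a 0/1 label.
dataset = TensorDataset(torch.arange(30.0).reshape(10, 3), torch.arange(10) % 2)

# 1) Number and size of batches per epoch.
for batch_size, drop_last in [(2, False), (3, False), (3, True)]:
    loader = DataLoader(dataset, batch_size=batch_size, drop_last=drop_last)
    sizes = [data.size(0) for data, _ in loader]
    print(batch_size, drop_last, len(loader), sizes)
# 2 False 5 [2, 2, 2, 2, 2]
# 3 False 4 [3, 3, 3, 1]
# 3 True 3 [3, 3, 3]


# 2) Repeatable shuffling with a seeded generator.
def make_loader(seed):
    g = torch.Generator()
    g.manual_seed(seed)  # controls the shuffling order of this loader
    return DataLoader(dataset, batch_size=2, shuffle=True, generator=g)


first_a = next(iter(make_loader(42)))
first_b = next(iter(make_loader(42)))
print(torch.equal(first_a[0], first_b[0]))  # True: same seed, same first batch
```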

What is Torchvision #

Torchvision is a Python package that is used in conjunction with PyTorch. It provides not only some commonly used datasets but also several pre-built classic network models and integrated image data processing tools, mainly used for data preprocessing. In simple terms, Torchvision is a combination of common datasets + common network models + common image processing methods.

Installing Torchvision is also very simple. You can use conda to install it, with the following command:

conda install torchvision -c pytorch

Alternatively, you can use pip to install it, with the following command:

pip install torchvision

By default, Torchvision uses the PIL image loader. Therefore, to ensure smooth operation of Torchvision, we also need to install a third-party image processing library for Python called Pillow. Pillow provides extensive file format support and powerful image processing capabilities, including image storage, display, format conversion, and basic image processing operations.

To install Pillow using conda, use the following command:

conda install pillow

To install Pillow using pip, use the following command:

pip install pillow

Using Torchvision to Read Data #

After installing Torchvision, let’s take a look at what support Torchvision provides for reading data.

The torchvision.datasets module in Torchvision provides interfaces for various image datasets. Common image datasets such as MNIST and COCO are encapsulated in this module.

The table below lists all the datasets supported by the torchvision.datasets module. For detailed explanations and interfaces of each dataset, please refer to the link: https://pytorch.org/vision/stable/datasets.html.

[Table: datasets supported by torchvision.datasets]

I would like to remind you that the torchvision.datasets module itself does not contain the actual dataset files. Its working mechanism is to first download the dataset files from the internet to the specified directory, and then load the dataset into memory using its loaders. Finally, the loaded dataset is returned to the user as an object.

To further deepen your understanding, let’s take the MNIST dataset as an example to demonstrate how this module is used.

Introduction to the MNIST Dataset #

The MNIST dataset is a well-known dataset of handwritten digits, and it is a classic introductory example in the field of deep learning due to its simplicity.

The MNIST dataset is a subset of the NIST dataset and can be downloaded from its official website. It consists of four parts, which I have summarized in the table below.

[Table: the four parts of the MNIST dataset]

The MNIST files are stored in a binary ubyte format. Let’s first parse the “Training set images” file into images so we can see what the dataset looks like. I will explain the parsing process later in the data preview section.

[Figure: images parsed from the MNIST training set]
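The ubyte files follow the simple IDX layout described on the MNIST website: a big-endian 4-byte magic number (2051 for image files, 2049 for label files, where the third byte 0x08 means unsigned bytes and the fourth byte is the number of dimensions), then one big-endian 4-byte size per dimension, then the raw pixel or label bytes. The sketch below parses that layout with NumPy; read_idx is a hypothetical helper, and the example builds a tiny fake image file in memory instead of using the real MNIST files.

```python
import struct

import numpy as np


def read_idx(raw: bytes) -> np.ndarray:
    """Parse an uncompressed IDX (ubyte) buffer into a NumPy array.

    Assumes the data-type byte is 0x08 (unsigned byte), as in all MNIST files.
    """
    zero, dtype_code, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0 or dtype_code != 0x08:
        raise ValueError("not an unsigned-byte IDX buffer")
    # One big-endian uint32 per dimension, e.g. (count, rows, cols) for images.
    shape = struct.unpack(">" + "I" * ndim, raw[4 : 4 + 4 * ndim])
    data = np.frombuffer(raw, dtype=np.uint8, offset=4 + 4 * ndim)
    return data.reshape(shape)


# Fake "image file": 2 images of 2x3 pixels (magic number 2051 = 0x00000803).
header = struct.pack(">IIII", 2051, 2, 2, 3)
pixels = bytes(range(12))
images = read_idx(header + pixels)
print(images.shape)  # (2, 2, 3)
```

With a real file, decompress it first, for example read_idx(gzip.open("train-images-idx3-ubyte.gz", "rb").read()), which should give an array of shape (60000, 28, 28).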

Data Reading #

Next, let’s see how to use Torchvision to read the MNIST dataset.

torchvision.datasets provides a built-in interface for each dataset it supports. For example, for the MNIST dataset mentioned earlier, it provides an interface called MNIST, which encapsulates the entire process: downloading, decompressing, reading, and parsing the data.

All of these interfaces follow the same download, load, and return workflow described earlier.

Taking MNIST as an example, we can use the following code:

# Using MNIST as an example
import torchvision
mnist_dataset = torchvision.datasets.MNIST(root='./data',
                                           train=True,
                                           transform=None,
                                           target_transform=None,
                                           download=True)

torchvision.datasets.MNIST is a class. Instantiating it will return an MNIST dataset object. The constructor has 5 parameters:

  • root is a string specifying the location where you want to save the MNIST dataset. If download is set to False, it will read the dataset from the specified location.
  • download is a boolean value indicating whether to download the dataset. If set to True, it will automatically download the dataset from the internet and store it in the location specified by root. If the dataset files already exist at the specified location, it will not download again.
  • train is a boolean value indicating whether to load the training set. If set to True, only the training data is loaded; if set to False, only the test data is loaded. Note that not all datasets are divided into training and test sets, and some select their split with a different parameter, so check the official documentation of each interface.
  • transform is used to preprocess the images, such as data augmentation, normalization, rotation, or scaling. We will explain these operations in the next lesson.
  • target_transform is used to preprocess the image labels.

By running the above code, we can get the following result. From the figure, we can see that the program first downloads the MNIST dataset from the specified URL and then performs decompression and other operations. If you run the same code again, the download process will not be repeated.

[Figure: output of downloading and decompressing the MNIST dataset]
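Built-in datasets such as MNIST apply the transform and target_transform parameters inside __getitem__: transform to the image and target_transform to the label, each only if it is not None. The sketch below imitates that pattern in a hypothetical TransformableDataset (not a Torchvision class) so the mechanism is visible without downloading anything; the lambdas merely stand in for real preprocessing.

```python
import torch
from torch.utils.data import Dataset


class TransformableDataset(Dataset):
    """Hypothetical dataset mirroring the transform/target_transform hooks."""

    def __init__(self, data, targets, transform=None, target_transform=None):
        self.data = data
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample, target = self.data[index], self.targets[index]
        if self.transform is not None:
            sample = self.transform(sample)         # preprocess the input
        if self.target_transform is not None:
            target = self.target_transform(target)  # preprocess the label
        return sample, target


data = torch.arange(6.0).reshape(3, 2)
targets = [0, 1, 2]
ds = TransformableDataset(
    data,
    targets,
    transform=lambda x: x / 5.0,       # e.g. scale the values
    target_transform=lambda y: y % 2,  # e.g. map labels to even/odd
)
print(ds[2])  # (tensor([0.8000, 1.0000]), 0)
```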

At this point, you might still have some questions and may be curious about what mnist_dataset is.

If you check the type of mnist_dataset with the type() function, you will get torchvision.datasets.mnist.MNIST, which is a subclass (indirectly, via Torchvision’s VisionDataset) of the Dataset class introduced earlier. In other words, torchvision.datasets has already subclassed Dataset and packaged the dataset for us, so we can use it directly.

MNIST is used as the example here, but other datasets are used in a similar way: replace the class name “MNIST” with the name of the other dataset, keeping in mind that their constructor parameters may differ slightly.

Different datasets store their data in different formats, but torchvision.datasets parses and reads them for us, which is very convenient. For image datasets without an official interface, we can use the torchvision.datasets.ImageFolder interface to load them ourselves. We will use ImageFolder for data loading in the image classification lesson, so you can take a closer look then.

Data Preview #

After completing the data reading, we obtain the mnist_dataset, which is a packaged dataset object.

To view the contents of mnist_dataset, one simple way is to convert it to a list. (If Jupyter reports that the IOPub data rate limit has been exceeded, you can load only the smaller test set by setting train=False.)

mnist_dataset_list = list(mnist_dataset)
print(mnist_dataset_list)

The executed result is shown in the following figure.

[Figure: the printed list of (image, label) tuples]

From the result, we can see that the converted dataset object becomes a list of tuples, with each tuple containing two elements. The first element is the image data, and the second element is the label of the image.
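Converting the whole dataset to a list builds every sample at once (60,000 for the MNIST training set), which is slow and produces a very long printout. Since a Dataset already supports len() and indexing, a lighter alternative is to index only the samples you want to inspect, or to wrap a few indices in torch.utils.data.Subset. In this sketch, a small TensorDataset stands in for mnist_dataset so the code runs without downloading MNIST.

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Stand-in for mnist_dataset: 100 fake 28x28 "images" with labels 0-9.
dataset = TensorDataset(torch.zeros(100, 28, 28), torch.arange(100) % 10)

# Option 1: index just the first few samples.
preview = [dataset[i] for i in range(3)]
print(len(preview), preview[0][1])  # 3 tensor(0)

# Option 2: a Subset is itself a Dataset, so it can also go into a DataLoader.
first_five = Subset(dataset, range(5))
print(len(first_five))  # 5
```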

The image data is of type PIL.Image.Image, which can be directly displayed in Jupyter Notebook. The code to display one data entry is as follows:

from IPython.display import display  # built into Jupyter; needed in other IPython sessions
display(mnist_dataset_list[0][0])
print("Image label is:", mnist_dataset_list[0][1])

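The result is shown in the figure: the first entry in mnist_dataset is an image of the handwritten digit “7”, and its label is 7. Note that this is the first image of the test set (loaded with train=False); with train=True, the first image of the training set is a handwritten “5”.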

[Figure: the first image, a handwritten “7”, and its label]

If you have obtained the same result as above, it means that your operations are correct. Congratulations on successfully completing the reading operation.
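Beyond display(), it often helps to look at the raw pixel values of a sample. A PIL image converts to a NumPy array with np.array(); for an MNIST image this should give a 28 x 28 array of uint8 values between 0 and 255. The sketch below uses a blank grayscale image created with Pillow as a stand-in for mnist_dataset_list[0][0].

```python
import numpy as np
from PIL import Image

# Stand-in for mnist_dataset_list[0][0]: a 28x28 grayscale ("L" mode) image.
image = Image.new("L", (28, 28), color=0)
image.putpixel((14, 14), 255)  # one white pixel at x=14, y=14

print(image.mode, image.size)  # L (28, 28)

pixels = np.array(image)           # shape is (height, width)
print(pixels.shape, pixels.dtype)  # (28, 28) uint8
print(pixels.min(), pixels.max())  # 0 255
```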

Summary #

Congratulations on completing this lesson. We have taken the first step in model training by learning how to read data.

The focus today was on two ways of reading data: defining a custom dataset, and reading commonly used image datasets.

The most general method of data reading is to define a derived class of Dataset. To read commonly used image datasets, we can use the Torchvision package provided by PyTorch.

The Torchvision library provides rich interfaces for reading image datasets. I demonstrated how to use Torchvision to read the MNIST dataset, a classic example of handwritten digit recognition.

The dataset classes in torchvision.datasets inherit from the Dataset class. Besides providing ready-made interfaces for many common datasets, they also reserve hooks (the transform and target_transform parameters) for data preprocessing and data augmentation. In the next lesson, we will explore these functions and learn how to perform data augmentation.

Practice for Each Lesson #

In PyTorch, which class should we inherit from when defining a dataset?

Feel free to leave a comment and interact with me in the comments section. I also recommend sharing this lesson with more friends and colleagues to learn and progress together.