05 Tensor Transformation: Quick Mastery of Tensor Slicing, Transforming, etc.

Hello, I am Fang Yuan.

In the previous lesson, we learned the basic concepts of tensors and became familiar with operations such as creation, conversion, and dimensional transformation. With this foundation, you can perform some simple tensor-related operations.

However, in order to use tensors more flexibly in practical applications, it is also essential to know how to perform operations such as tensor concatenation and slicing. In today’s lesson, we will learn together through examples and images. Although these operations may be a bit challenging, as long as you patiently listen to my explanations and practice, you will be able to master them.

Tensor Concatenation Operations #

In real projects, the input to a given layer of a deep learning network may come from several different sources, and these pieces of data need to be combined. We call this combination operation concatenation.

cat #

The concatenation function is defined as follows:

torch.cat(tensors, dim=0, out=None)

“cat” stands for “concatenate”, which means to join or connect. This function has two important parameters that you need to understand.

The first parameter, tensors, is easy to understand: it is the sequence of tensors to be concatenated. Apart from the dimension being concatenated, they must all have the same shape.

The second parameter is dim. Recall that a tensor can have multiple dimensions. For example, two 3D tensors can be concatenated in several different ways (as shown in the following figure), and dim specifies which dimension they are joined along.

[Figure: two 3D tensors concatenated along different dimensions]

Since the figure is in 3D, it may look a bit confusing at first. So let's start with the simpler 2D case and declare two 3x3 matrices:

>>> A = torch.ones(3, 3)
>>> B = 2 * torch.ones(3, 3)
>>> A
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>> B
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])

Let’s look at what happens when dim=0:

>>> C = torch.cat((A, B), 0)
>>> C
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])

The two matrices are joined along dim 0, the "row" dimension: the rows of B are appended below the rows of A, giving a 6x3 matrix.

Now let’s see what happens when dim=1:

>>> D = torch.cat((A, B), 1)
>>> D
tensor([[1., 1., 1., 2., 2., 2.],
        [1., 1., 1., 2., 2., 2.],
        [1., 1., 1., 2., 2., 2.]])

This time the matrices are joined along dim 1, the "column" dimension: the columns of B are appended to the right of those of A, giving a 3x6 matrix.

What if the tensors are 3D or even higher-dimensional? The principle is the same: dim selects the dimension along which the tensors are joined, and only the size of that dimension changes in the result.
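As a quick check of this rule in 3D, here is a minimal sketch that concatenates two tensors of an arbitrarily chosen shape (2, 3, 4) along each of their three dimensions and prints the resulting shapes. The last call shows that the inputs may differ in size only along the concatenated dimension.

```python
import torch

# Two 3D tensors with the same (arbitrary) shape (2, 3, 4).
A = torch.zeros(2, 3, 4)
B = torch.ones(2, 3, 4)

# Only the size of the chosen dimension grows; the other sizes are unchanged.
print(torch.cat((A, B), dim=0).shape)  # torch.Size([4, 3, 4])
print(torch.cat((A, B), dim=1).shape)  # torch.Size([2, 6, 4])
print(torch.cat((A, B), dim=2).shape)  # torch.Size([2, 3, 8])

# Sizes may differ only along the dimension being concatenated.
C = torch.ones(2, 5, 4)
print(torch.cat((A, C), dim=1).shape)  # torch.Size([2, 8, 4])
```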

You may ask: since cat joins tensors along an existing dimension, how can we join them along a new one? This is where the stack function comes in.

stack #

To help you understand better, let’s look at a specific example. Suppose we have two 2D matrices as tensors and we want to “stack” them together to create a 3D tensor, as shown in the following figure:

[Figure: two 2D matrices stacked into a 3D tensor]

The inputs are 2D (rank 2) and the result is 3D (rank 3): stacking adds a new dimension. Note how this differs from cat. In the cat illustration, the input tensors were already 3D and the result was still 3D, whereas here the number of dimensions grows from 2 to 3.

In practical image algorithm development, we often need to merge several single-channel tensors (2D) into one multi-channel result (3D). This way of concatenating by adding a new dimension is called stacking, and PyTorch provides the stack function for it.

The definition of the stack function is as follows:

torch.stack(tensors, dim=0)

Here, tensors is the sequence of tensors to be stacked, which must all have the same shape, and dim is the position at which the new dimension is inserted in the result.

So how do we use stack? Let’s look at an example together:

>>> A = torch.arange(0, 4)
>>> A
tensor([0, 1, 2, 3])
>>> B = torch.arange(5, 9)
>>> B
tensor([5, 6, 7, 8])
>>> C = torch.stack((A, B), 0)
>>> C
tensor([[0, 1, 2, 3],
        [5, 6, 7, 8]])
>>> D = torch.stack((A, B), 1)
>>> D
tensor([[0, 5],
        [1, 6],
        [2, 7],
        [3, 8]])

In this code, A and B are one-dimensional vectors with 4 elements each. Stacking along dim=0 inserts the new dimension at the front, so each input becomes a row and C is a 2D tensor of shape (2, 4). Stacking along dim=1 inserts the new dimension after the existing one, so each input becomes a column and D has shape (4, 2).
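The following sketch illustrates two points under simple assumptions: stack gives the same result as first inserting a size-1 dimension with unsqueeze and then calling cat, and three single-channel planes can be stacked into one channels-first image tensor. The 2x2 planes are placeholder data, not real image channels.

```python
import torch

A = torch.arange(0, 4)
B = torch.arange(5, 9)

# stack inserts a new dimension at position `dim`, then joins along it.
# The same result can be built by hand with unsqueeze + cat.
D_stack = torch.stack((A, B), dim=1)
D_manual = torch.cat((A.unsqueeze(1), B.unsqueeze(1)), dim=1)
print(torch.equal(D_stack, D_manual))  # True

# Image-style use: merge three single-channel (H, W) planes into one
# (3, H, W) tensor. The 2x2 planes are made-up placeholder data.
r = torch.zeros(2, 2)
g = torch.ones(2, 2)
b = torch.full((2, 2), 2.0)
rgb = torch.stack((r, g, b), dim=0)
print(rgb.shape)  # torch.Size([3, 2, 2])
```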

Tensor Split Operations #

After learning about concatenation operations, let’s take a look at the inverse operation of concatenation: split.

As with concatenation, there are several ways to split a tensor. The three main split functions are chunk, split, and unbind.

That may seem like a lot at first, but each one has its own characteristics and suits different use cases. Let's take a closer look together.

chunk #

chunk divides a tensor along a specified dimension into a given number of pieces that are as equal in size as possible.

For example, if we have a feature with 32 channels and we want to evenly divide it into 4 groups with 8 channels in each group, we can achieve this splitting using the chunk function. The specific function definition is as follows:

torch.chunk(input, chunks, dim=0)

Let’s take a look at the three parameters involved in the function:

First is the input parameter, which represents the Tensor to be split.

Next, we have chunks, which represents the number of chunks to be divided, not the number of elements in each group. Please note that chunks must be an integer.

Lastly, we have dim, which represents along which dimension the split should be performed.
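Going back to the 32-channel example, a minimal sketch might look like this. It assumes an unbatched feature map laid out as (C, H, W) with a made-up 8x8 spatial size; for a batched (N, C, H, W) tensor you would pass dim=1 instead.

```python
import torch

# Hypothetical feature map: 32 channels in (C, H, W) layout, 8x8 spatial size.
feature = torch.rand(32, 8, 8)

# Request 4 chunks along the channel dimension (dim=0).
groups = torch.chunk(feature, 4, dim=0)

print(len(groups))                   # 4
print([g.shape[0] for g in groups])  # [8, 8, 8, 8]
```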

Just like before, let’s go through a few code examples to gain an intuitive understanding. Let’s start with a simple one-dimensional vector:

>>> A = torch.tensor([1,2,3,4,5,6,7,8,9,10])
>>> B = torch.chunk(A, 2, 0)
>>> B
(tensor([1, 2, 3, 4, 5]), tensor([6, 7, 8, 9, 10]))

Here, chunk splits A, which has length 10, into two vectors of length 5 each. (Note that B is a tuple containing the two results.)

Now, what happens if the length of A is not divisible by chunks? Let's continue:

>>> B = torch.chunk(A, 3, 0)
>>> B
(tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8]), tensor([9, 10]))

This time, A is split into three vectors of lengths 4, 4, and 2. The size of each piece is 10 / 3 rounded up, that is 4, and the last piece holds whatever remains.

To confirm this rule, let's look at a slightly larger example in which A has length 17:

>>> A = torch.tensor([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17])
>>> B = torch.chunk(A, 4, 0)
>>> B
(tensor([1, 2, 3, 4, 5]), tensor([6, 7, 8, 9, 10]), tensor([11, 12, 13, 14, 15]), tensor([16, 17]))

A, with length 17, is split into four vectors of lengths 5, 5, 5, and 2. This confirms the rule: each piece has ceil(17 / 4) = 5 elements, and the last piece takes the remainder.
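The rounding rule can be checked directly. The sketch below uses a hypothetical helper, expected_chunk_sizes (not part of PyTorch), that applies the rule described above, and compares its predictions with what torch.chunk actually returns for several lengths, including a 6-into-4 case where fewer chunks come back than were requested.

```python
import math
import torch

def expected_chunk_sizes(length, chunks):
    """Hypothetical helper: predict torch.chunk piece sizes for one dimension.

    Each piece gets ceil(length / chunks) elements and the last piece gets
    what remains, so fewer than `chunks` pieces may be produced.
    """
    size = math.ceil(length / chunks)
    sizes = []
    remaining = length
    while remaining > 0:
        sizes.append(min(size, remaining))
        remaining -= size
    return sizes

for length, chunks in [(10, 2), (10, 3), (17, 4), (3, 5), (6, 4)]:
    actual = [t.numel() for t in torch.chunk(torch.arange(length), chunks)]
    print(length, chunks, actual, expected_chunk_sizes(length, chunks))
```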

What if chunks is larger than the size of the dimension being split? Let's try it out:

>>> A = torch.tensor([1,2,3])
>>> B = torch.chunk(A, 5, 0)
>>> B
(tensor([1]), tensor([2]), tensor([3]))

In this case every piece has length 1, so we get only 3 pieces even though we asked for 5. In other words, chunk may return fewer chunks than requested.

The same logic applies to two-dimensional tensors. Let's look at a matrix example:

>>> A = torch.ones(4, 4)
>>> A
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
>>> B = torch.chunk(A, 2, 0)
>>> B
(tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.]]),
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.]]))

Just like with concatenation, the dim parameter specifies the dimension along which the split is performed. Here, dim=0 divides the 4 rows into two 2x4 blocks.
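Splitting along the other dimension works the same way. In this sketch (the 3x4 matrix is an arbitrary example), dim=1 divides the columns instead of the rows:

```python
import torch

A = torch.arange(0, 12).view(3, 4)

# dim=1 splits the columns: 4 columns into 2 chunks of 2 columns each.
left, right = torch.chunk(A, 2, dim=1)
print(left)   # columns 0 and 1
print(right)  # columns 2 and 3
```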

The chunk function splits a tensor into a predetermined number of pieces. What if we want to specify the size of each piece instead? PyTorch provides a function for this as well, called split.

split #

The function definition of split is as follows, and as before, let’s take a look at the parameters involved here:

torch.split(tensor, split_size_or_sections, dim=0)

First is tensor, the tensor to be split. Next is split_size_or_sections. When it is an integer, the tensor is split into blocks of that size along dim (the last block may be smaller). When it is a list (or tuple) of integers, the tensor is split into one block per element, with each block's size given by the corresponding element; these sizes must add up to the size of dim.

Finally, we have dim, which defines the dimension along which to split.

Similarly, let’s look at some examples to see how the split function works. First, let’s consider the case when split_size_or_sections is an integer.

>>> A = torch.rand(4,4)
>>> A
tensor([[0.6418, 0.4171, 0.7372, 0.0733],
        [0.0935, 0.2372, 0.6912, 0.8677],
        [0.5263, 0.4145, 0.9292, 0.5671],
        [0.2284, 0.6938, 0.0956, 0.3823]])
>>> B = torch.split(A, 2, 0)
>>> B
(tensor([[0.6418, 0.4171, 0.7372, 0.0733],
        [0.0935, 0.2372, 0.6912, 0.8677]]), 
 tensor([[0.5263, 0.4145, 0.9292, 0.5671],
        [0.2284, 0.6938, 0.0956, 0.3823]]))

Here, the 4x4 tensor A is split along the first dimension (dim 0, the "row" dimension) into blocks of 2 rows each, giving two 2x4 tensors.

Now let's consider what happens when the size of the dimension is not divisible by split_size_or_sections. We modify the code as follows:

>>> C = torch.split(A, 3, 0)
>>> C
(tensor([[0.6418, 0.4171, 0.7372, 0.0733],
        [0.0935, 0.2372, 0.6912, 0.8677],
        [0.5263, 0.4145, 0.9292, 0.5671]]), 
 tensor([[0.2284, 0.6938, 0.0956, 0.3823]]))

From this result, we can see that PyTorch makes each block split_size_or_sections elements long along the given dimension. If there are not enough elements left for a full block, the remaining ones form a smaller final block.
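To make the relationship between the two forms concrete, this sketch checks that an integer split size of 3 on 4 rows behaves like the explicit section list [3, 1], and contrasts it with chunk, where 3 is the number of pieces rather than their size. The random 4x4 tensor is just sample data.

```python
import torch

A = torch.rand(4, 4)

# An integer size of 3 on a dimension of size 4 gives pieces of 3 and 1 rows,
# which matches passing the explicit list of sizes [3, 1].
by_size = torch.split(A, 3, dim=0)
by_list = torch.split(A, [3, 1], dim=0)
print([t.shape[0] for t in by_size])                             # [3, 1]
print(all(torch.equal(x, y) for x, y in zip(by_size, by_list)))  # True

# With chunk, 3 is the number of pieces: ceil(4 / 3) = 2 rows each,
# so only two pieces come back.
print([t.shape[0] for t in torch.chunk(A, 3, dim=0)])  # [2, 2]
```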

Next, let's look at the case when split_size_or_sections is a list, in which each element gives the size of one block. Here is an example:

>>> A = torch.rand(5,4)
>>> A
tensor([[0.1005, 0.9666, 0.5322, 0.6775],
        [0.4990, 0.8725, 0.5627, 0.8360],
        [0.3427, 0.9351, 0.7291, 0.7306],
        [0.7939, 0.3007, 0.7258, 0.9482],
        [0.7249, 0.7534, 0.0027, 0.7793]])
>>> B = torch.split(A, (2,3), 0)
>>> B
(tensor([[0.1005, 0.9666, 0.5322, 0.6775],
        [0.4990, 0.8725, 0.5627, 0.8360]]), 
 tensor([[0.3427, 0.9351, 0.7291, 0.7306],
        [0.7939, 0.3007, 0.7258, 0.9482],
        [0.7249, 0.7534, 0.0027, 0.7793]]))

In this code, we first create a 5x4 matrix A. Then we split it along the first dimension into two blocks of 2 rows and 3 rows, respectively.
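The sections form works along any dimension, and PyTorch checks that the sizes add up. In this minimal sketch, the 5x4 shape matches the example above, while the section sizes are chosen for illustration:

```python
import torch

A = torch.rand(5, 4)

# Sections along dim=1: 4 columns split into pieces of 1 and 3 columns.
left, right = torch.split(A, [1, 3], dim=1)
print(left.shape, right.shape)  # torch.Size([5, 1]) torch.Size([5, 3])

# The listed sizes must add up exactly to the size of the split dimension.
error_raised = False
try:
    torch.split(A, [2, 2], dim=0)  # 2 + 2 != 5 rows
except RuntimeError as err:
    error_raised = True
    print("RuntimeError:", err)
```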

unbind #

With the previous functions, we now know how to split a tensor into a given number of pieces or into pieces of given sizes. Now imagine that we have a 3-channel image tensor and want to access each channel's data one by one. How can we do that?

If we use chunk, we need to set the number of chunks to 3. If we use split, we need to set split_size_or_sections to 1.

Although they work, these approaches are indirect: with chunk you must pass the number of channels, and each piece returned by chunk or split still carries a channel dimension of size 1 that has to be removed before use. For this case, PyTorch provides another function, unbind, defined as follows:

torch.unbind(input, dim=0)

Here, input is the tensor to be processed, and dim is the dimension to slice along; that dimension is removed from each result.

Let’s understand it with an example:

>>> A = torch.arange(0,16).view(4,4)
>>> A
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
>>> b = torch.unbind(A, 0)
>>> b
(tensor([0, 1, 2, 3]), 
 tensor([4, 5, 6, 7]), 
 tensor([ 8,  9, 10, 11]), 
 tensor([12, 13, 14, 15]))

In this example, we first create a 4x4 matrix Tensor. Then, we split it along the first dimension, which is the “row” dimension. Since the matrix has 4 rows, we get 4 results.

Next, let’s see what happens if we split it along the second dimension, which is the “column” direction:

>>> b = torch.unbind(A, 1)
>>> b
(tensor([ 0,  4,  8, 12]), 
 tensor([ 1,  5,  9, 13]), 
 tensor([ 2,  6, 10, 14]), 
 tensor([ 3,  7, 11, 15]))

This time each result is a column of A. In both cases, unbind returns a tuple of all slices along dim with that dimension removed, so each result has one dimension fewer than the input.
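Applied to the 3-channel image scenario from the start of this section, a minimal sketch could look like the following. The tiny 3x2x2 tensor stands in for a real channels-first image; the comparison with split shows the extra size-1 dimension that unbind removes for us.

```python
import torch

# Made-up 3-channel "image" in (C, H, W) layout with a 2x2 spatial size.
img = torch.arange(12).view(3, 2, 2)

# unbind along the channel dimension: three (H, W) tensors, channel dim removed.
r, g, b = torch.unbind(img, dim=0)
print(r.shape)  # torch.Size([2, 2])

# split with size 1 keeps a size-1 channel dim; squeezing it gives the same data.
pieces = torch.split(img, 1, dim=0)
print(pieces[0].shape)                       # torch.Size([1, 2, 2])
print(torch.equal(pieces[0].squeeze(0), r))  # True

# Iterating over the returned tuple works for any number of channels.
for i, channel in enumerate(torch.unbind(img, dim=0)):
    print(i, channel.sum().item())
```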

Indexing Operations for Tensors #

Did you notice that in the chunk and split operations we discussed earlier, we sliced the data as a whole and obtained all the results? But sometimes, we only need a part of the data. How can we achieve this? A natural idea is to directly tell the Tensor which parts we want, and this method is called indexing.

Indexing operations can be done in many ways, including pre-defined APIs and user-defined operations. Among them, the two most commonly used operations are index_select and masked_select. Let’s take a look at their usage.

index_select #

To pick specific rows or columns by their positions, we can use the index_select function, defined as follows:

torch.index_select(input, dim, index)

The input and dim parameters play the same roles as in the previous functions, so we won't go into details. What we need to focus on is index, which gives the positions to select along dimension dim. Note that index must be a 1-D torch.Tensor of integers (int64 or int32), not a Python list.

Let’s look at some examples:

>>> A = torch.arange(0, 16).view(4, 4)
>>> A
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
>>> B = torch.index_select(A, 0, torch.tensor([1, 3]))
>>> B
tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15]])
>>> C = torch.index_select(A, 1, torch.tensor([0, 3]))
>>> C
tensor([[ 0,  3],
        [ 4,  7],
        [ 8, 11],
        [12, 15]])

In this example, we first create a 4x4 matrix A. Then we use dim=0 to select the rows with indices 1 and 3 (the second and fourth rows, since indices start at 0), which gives tensor B of size 2x4. Next, we use dim=1 to select the columns with indices 0 and 3, which gives tensor C of size 4x2.
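A few more properties of index_select are worth knowing. As this sketch shows, indices may repeat and appear in any order, and for these simple cases the result matches PyTorch's advanced indexing syntax (A[idx] and A[:, idx]). Unlike chunk and split, which return views, index_select returns a new tensor that does not share storage with the original.

```python
import torch

A = torch.arange(0, 16).view(4, 4)

# Indices may repeat and come in any order; the output follows the index order.
idx = torch.tensor([3, 0, 3])
B = torch.index_select(A, 0, idx)
print(B)

# Advanced indexing with the same index tensor selects the same rows.
print(torch.equal(B, A[idx]))  # True

# Selecting columns with dim=1 matches A[:, [0, 3]].
C = torch.index_select(A, 1, torch.tensor([0, 3]))
print(torch.equal(C, A[:, [0, 3]]))  # True
```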

Isn’t it simple?

masked_select #

The index_select operation extracts data at given indices. But sometimes we want to select data based on a condition, such as extracting the parameters greater than 0 in a specific layer of a deep learning network.

For this, we need to use the masked_select function provided by PyTorch. Let’s take a look at its definition:

torch.masked_select(input, mask, out=None)

Here, we only need to pay attention to the first two parameters, input and mask.

input is the tensor to be processed. mask is a boolean tensor that marks which elements satisfy the condition. Note that the shape of mask does not need to match that of input, but the two shapes must be broadcastable.

Feeling a bit confused? Let me give you an example that will make everything clear.

Have you ever wondered what would happen if we compare a tensor with a number? For example, in the following code, we randomly generate a 5-element tensor A:

>>> A = torch.rand(5)
>>> A
tensor([0.3731, 0.4826, 0.3579, 0.4215, 0.2285])
>>> B = A > 0.3
>>> B
tensor([ True,  True,  True,  True, False])

In this code snippet, we compare tensor A with 0.3, and obtain a new tensor B, where each element indicates whether the corresponding value in A is greater than 0.3.

For example, the first value in A is 0.3731, which is greater than 0.3, so it is True; the last value 0.2285 is less than 0.3, so it is False.

This new tensor is a mask tensor: each of its elements records the result of the comparison for the corresponding element of A.

Then, let’s continue with some code to see how the selection based on the mask tensor B works:

>>> C = torch.masked_select(A, B)
>>> C
tensor([0.3731, 0.4826, 0.3579, 0.4215])

As you can see, C contains exactly the elements of A at the positions where B is True.

Now you should understand the purpose of masked_select: first build a mask tensor from the desired filtering condition, then use it to extract the matching elements. The result is always a 1-D tensor.

Based on this idea, the above example can be simplified as follows:

>>> A = torch.rand(5)
>>> A
tensor([0.3731, 0.4826, 0.3579, 0.4215, 0.2285])
>>> C = torch.masked_select(A, A > 0.3)
>>> C
tensor([0.3731, 0.4826, 0.3579, 0.4215])
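The sketch below goes a step further. It shows that masked_select returns a 1-D tensor with the same values as boolean indexing A[mask], that a mask only needs to be broadcastable to the input, and, echoing the motivation above, how one might pull the positive weights out of a layer. The nn.Linear layer with 4 inputs and 3 outputs is a made-up example.

```python
import torch

A = torch.arange(0, 12).view(3, 4)

# An elementwise condition produces a boolean mask with the same shape as A.
mask = A % 2 == 0
evens = torch.masked_select(A, mask)
print(evens)                        # a 1-D tensor of the selected values
print(torch.equal(evens, A[mask]))  # boolean indexing selects the same values

# A (4,) mask broadcasts across the rows of the (3, 4) input,
# selecting columns 0 and 3 of every row, in row-major order.
col_mask = torch.tensor([True, False, False, True])
print(torch.masked_select(A, col_mask))

# Hypothetical layer: collect the weights that are greater than 0.
torch.manual_seed(0)
layer = torch.nn.Linear(4, 3)
positive = torch.masked_select(layer.weight, layer.weight > 0)
print(positive.shape)
```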

Isn’t it simple?

Summary #

Congratulations on completing this lesson. We learned more advanced tensor operations: combining tensors (cat and stack), splitting them (chunk, split, and unbind), and selecting data by index or by condition (index_select and masked_select).

When using these functions, pay close attention to the key parameters, especially dim and the sizes. Work out the expected shapes carefully in advance to avoid incorrect results.

With the help of numerous examples, I believe you will be able to master these operations.

To summarize and consolidate the main functions covered in this lesson, I have prepared a table of the functions and their parameters. There is no need to memorize these parameters; instead, refer to the relevant parameter list as needed when you use them.

[Table: summary of the tensor functions covered in this lesson and their parameters]

With these two lessons, we have covered a series of tensor operations. In future projects, you will be able to manipulate tensors flexibly and creatively with ease. Keep up the good work!

Practice for Each Lesson #

Now we have a tensor as follows:

>>> A=torch.tensor([[4,5,7], [3,9,8],[2,3,4]])
>>> A
tensor([[4, 5, 7],
        [3, 9, 8],
        [2, 3, 4]])

We want to extract the first element of the first row, the first and second elements of the second row, and the last element of the third row. How do we do it?

Feel free to interact with me in the comments section, and I also encourage you to share this lesson with your colleagues and friends!