Advanced Selection from Tensors in PyTorch | by Oliver S | Feb, 2024


Using torch.index_select, torch.gather and torch.take

In some situations, you'll need to do some advanced indexing / selection with PyTorch, e.g. answer the question: "how can I select elements from Tensor A following the indices specified in Tensor B?"

In this post we'll present the three most common methods for such tasks, namely torch.index_select, torch.gather and torch.take. We'll explain all of them in detail and contrast them with one another.
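
To make the question above concrete, here is a minimal sketch that applies all three functions to the same small tensor. The tensor `values` and the index tensors are made-up illustration data, chosen only so that the results are easy to read off by hand.

```python
import torch

# Hypothetical example data: a 2 x 3 tensor with easily recognizable entries.
values = torch.tensor([[10, 11, 12],
                       [20, 21, 22]])

# index_select: a 1D index, applied identically to every row (dim=1).
by_index_select = torch.index_select(values, 1, torch.tensor([2, 0]))
# -> [[12, 10], [22, 20]]

# gather: an index tensor with the same number of dimensions as the input,
# so every row can pick different columns (dim=1).
by_gather = torch.gather(values, 1, torch.tensor([[2, 0], [1, 1]]))
# -> [[12, 10], [21, 21]]

# take: treats the input as if it were flattened to 1D.
by_take = torch.take(values, torch.tensor([0, 5]))
# -> [10, 22]

print(by_index_select, by_gather, by_take, sep="\n")
```

Roughly speaking, the three differ in how much freedom the index tensor has: one shared 1D index, one index per output element along a dimension, or flat positions into the whole tensor.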

Photo by Jerin J on Unsplash

Admittedly, one motivation for this post was me forgetting how and when to use which function, ending up googling, browsing Stack Overflow and the, in my opinion, relatively brief and not too helpful official documentation. Thus, as mentioned, we here do a deep dive into these functions: we motivate when to use which, give examples in 2D and 3D, and show the resulting selection graphically.

I hope this post will bring clarity about said functions and remove the need for further exploration. Thanks for reading!

And now, without further ado, let's dive into the functions one by one. For each, we first start with a 2D example and visualize the resulting selection, and then move to a somewhat more complex example in 3D. Further, we re-implement the executed operation in simple Python, so that you can look at pseudocode as another source of information on what these functions do. In the end, we summarize the functions and their differences in a table.

torch.index_select selects elements along one dimension, while keeping the other ones unchanged. That is: keep all elements from all other dimensions, but pick elements in the target dimension following the index tensor. Let's demonstrate this with a 2D example, in which we select along dimension 1:

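import torch

# Example sizes (assumed here for illustration; any positive values work)
len_dim_0 = 3
len_dim_1 = 4
num_picks = 2

values = torch.rand((len_dim_0, len_dim_1))
# 1D tensor of column indices in [0, len_dim_1)
indices = torch.randint(0, len_dim_1, size=(num_picks,))
# [len_dim_0, num_picks]
picked = torch.index_select(values, 1, indices)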

The resulting tensor has shape [len_dim_0, num_picks]: for every element along dimension 0, we have picked the same elements from dimension 1. Let's visualize this:
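
As a rough textual stand-in for a picture, the sketch below repeats the selection with fixed, hand-picked values instead of random ones, and compares the result against a plain-Python loop. The helper `index_select_dim1` is a hypothetical name introduced here for illustration; it only covers the 2D, dimension-1 case and is not how PyTorch implements the operation.

```python
import torch

# Hand-picked values so the selected columns are easy to spot:
# the tens digit encodes the row, the ones digit the column.
values = torch.tensor([[10., 11., 12., 13.],
                       [20., 21., 22., 23.],
                       [30., 31., 32., 33.]])
indices = torch.tensor([3, 1])  # pick column 3, then column 1

picked = torch.index_select(values, 1, indices)


def index_select_dim1(values, indices):
    """Plain-Python reference for index_select along dim 1 of a 2D tensor."""
    rows = []
    for i in range(values.shape[0]):       # keep every row ...
        row = []
        for j in indices.tolist():         # ... but only the listed columns
            row.append(values[i, j].item())
        rows.append(row)
    return torch.tensor(rows)


reference = index_select_dim1(values, indices)

print(picked)
# tensor([[13., 11.],
#         [23., 21.],
#         [33., 31.]])
print(torch.equal(picked, reference))  # True
```

Every row keeps its position, and within each row the same two columns (3 and 1) appear in the order given by the index tensor.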
