## PyTorch scatter_ Function Explained

Posted on 2020-10-01

In PyTorch, `self.scatter_(dim, index, src)` writes the values of the tensor `src` into the tensor `self`, in place. The best way I found to think about this function is that PyTorch iterates through the `src` tensor: for every element, it looks at the corresponding element in `index` (which should have the same number of dimensions as `src`) to see where in `self` to put this element. Note that the element in `index` only gives the coordinate along the dimension specified by `dim`; the coordinates along the remaining dimensions are taken from the position of the `src` element itself.

If this feels confusing, no worries. Here is an example as given by the official documentation that I will explain in depth.

```
>>> import torch
>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])
```

Note that here `dim` is 0, `index` is `torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])`, and `src` is the tensor `x`, which has shape 2 * 5.

PyTorch iterates over the `src` tensor (`x`) to write its values into the `self` tensor (which is a 3 * 5 tensor filled with 0's). At the first step, `src[0][0]`, the value is `0.3992`, and PyTorch needs to find where (at which index) in `self` to write this value.

To do that, it first looks at the corresponding element of `index`, `index[0][0]`, which gives the dimension-0 coordinate in `self` (recall that `dim` is 0). That element is 0 in this case, so PyTorch knows the target is something like `self[0][some_index]`; it then fills the remaining dimension with the position of the `src` element it is currently iterating over. Therefore, in this case, it writes to `self[0][0]` (again, dimension 0 is 0 because it's specified by `index`, while dimension 1 is 0 because the dimension-1 position of `src[0][0]` is 0).

At the second step, `src[0][1]`, the value is `0.2908`. PyTorch first looks at the corresponding element `index[0][1]` to see that the index specified for dimension 0 is 1, then fills the remaining dimension with the position of the `src` element; therefore, the value `0.2908` is written to `self[1][1]` (note again that dimension 0 is 1 because it's specified by the value in `index`, and dimension 1 is 1 because the dimension-1 position of `src[0][1]` is 1). You can try the subsequent iterations yourself.
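If it helps, the walk-through above can be replayed in plain Python. This is only a rough sketch for illustration, not how PyTorch implements `scatter_` internally: it uses nested lists instead of tensors, and the names `trace_scatter_dim0`, `dest` and `steps` are made up for this example.

```python
# Values copied from the documentation example above.
src = [[0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
       [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]]
index = [[0, 1, 2, 0, 0],
         [2, 0, 0, 1, 2]]


def trace_scatter_dim0(src, index, num_rows):
    """Mimic self.scatter_(0, index, src) on nested lists, recording each write."""
    num_cols = len(src[0])
    # Stand-in for the zero-filled self tensor, e.g. torch.zeros(3, 5).
    dest = [[0.0] * num_cols for _ in range(num_rows)]
    steps = []
    for i, row in enumerate(src):
        for j, value in enumerate(row):
            target_row = index[i][j]     # dimension 0 comes from index
            dest[target_row][j] = value  # dimension 1 is the column of src
            steps.append((i, j, target_row, j, value))
    return dest, steps


dest, steps = trace_scatter_dim0(src, index, num_rows=3)

# Show the first two steps discussed in the text.
for i, j, r, c, value in steps[:2]:
    print(f"src[{i}][{j}] = {value} -> self[{r}][{c}]")
# src[0][0] = 0.3992 -> self[0][0]
# src[0][1] = 0.2908 -> self[1][1]
```

Printing the rest of `steps` gives the remaining iterations, and `dest` ends up with the same values as the 3 * 5 result shown in the documentation output.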

In general, if you're dealing with 2 dimensions and `dim` is 0, the rule is `self[index[i][j]][j] = src[i][j]`; with `dim` set to 1 it becomes `self[i][index[i][j]] = src[i][j]`. Hope this clears up the concept a little. If you're still confused, refer to the official docs again and try working through a few examples with pen and paper!
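The 2-D rule can also be written out as a small loop. The sketch below is a pure-Python stand-in that works on nested lists; `scatter_2d` is a hypothetical helper name, not part of PyTorch, and it assumes `index` and `src` have the same shape and that every value in `index` is a valid position in `dest`.

```python
def scatter_2d(dest, dim, index, src):
    """Pure-Python stand-in for dest.scatter_(dim, index, src) on 2-D nested lists."""
    for i, index_row in enumerate(index):
        for j, target in enumerate(index_row):
            if dim == 0:
                dest[target][j] = src[i][j]  # index picks the row, column stays j
            elif dim == 1:
                dest[i][target] = src[i][j]  # index picks the column, row stays i
            else:
                raise ValueError("this sketch only handles dim 0 or 1")
    return dest


# dim=1 example: each row of src is scattered into the columns chosen by index.
src = [[1, 2, 3],
       [4, 5, 6]]
index = [[2, 0, 1],
         [0, 2, 1]]
result = scatter_2d([[0] * 3 for _ in range(2)], 1, index, src)
print(result)  # [[2, 3, 1], [4, 6, 5]]
```

With `dim=1` the rows never mix: the value `1` lands in column 2 of row 0, `2` in column 0, and so on. Passing `dim=0` to the same helper reproduces the `self[index[i][j]][j] = src[i][j]` behaviour from the walk-through.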