
Tensor.gather()

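Gathers values from the input tensor along the dimension dim, at the positions given by the index tensor. For a 2-D input and dim=1, each output element is out[i][j] = input[i][index[i][j]]. For dim=0 it is out[i][j] = input[index[i][j]][j]. The output tensor has the same shape as index, and index must have the same number of dimensions as the input. In the version documented here, the index tensor is passed first and the dimension second, as in a1.gather(a2, 1). Newer tinygrad releases use the order gather(dim, index), so check the signature of your installed version.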
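The indexing rule can be checked without tinygrad. The pure-Python sketch below assumes the torch-style semantics described above and handles only 2-D nested lists. `gather_2d` is a hypothetical helper for illustration, not part of tinygrad.

```python
def gather_2d(data, index, dim):
    """Hypothetical reference gather for 2-D nested lists (not tinygrad).

    The result has the same shape as `index`:
      dim=0: out[i][j] = data[index[i][j]][j]
      dim=1: out[i][j] = data[i][index[i][j]]
    """
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for 2-D input")
    out = []
    for i, row in enumerate(index):
        out_row = []
        for j, k in enumerate(row):
            # k picks a position along `dim`; the other coordinate is kept as-is
            out_row.append(data[k][j] if dim == 0 else data[i][k])
        out.append(out_row)
    return out


a1 = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]
a2 = [[0, 1], [1, 0]]
print(gather_2d(a1, a2, 1))  # [[1, 2], [4, 3]]
print(gather_2d(a1, a2, 0))  # [[1, 4], [3, 2]]
```

With dim=1 it reproduces the values shown under Return value (as ints rather than floats). With dim=0 the same index selects rows instead of columns.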

Usage

from tinygrad.tensor import Tensor
 
a1 = Tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
a2 = Tensor([[0, 1], [1, 0]])
 
tensor = a1.gather(a2, 1)  # index tensor first, then dim: out[i][j] = a1[i][a2[i][j]]
 
print(tensor.numpy())

Return value

[[1. 2.]
 [4. 3.]]
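Each output row can be traced back to a1 by applying the dim=1 rule to the corresponding row of a2 (rows [0, 1] and [1, 0]):

$$
\begin{aligned}
\text{out}_{0} &= [\,\text{a1}_{0,0},\ \text{a1}_{0,1}\,] = [1,\ 2] \\
\text{out}_{1} &= [\,\text{a1}_{1,1},\ \text{a1}_{1,0}\,] = [4,\ 3]
\end{aligned}
$$

Only rows 0 and 1 of a1 are read, because a2 has two rows. The output takes the shape of a2, (2, 2).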