Как Pytorch Tensor получает индекс определенной стоимости
В списке python можно использовать list.index(somevalue)
. Как может pytorch сделать это?
Например:
a=[1,2,3]
print(a.index(2))
Тогда, 1
будет выводиться. Как может тензор pytorch сделать это без преобразования его в список python?
Ответы
Ответ 1
Я думаю, что нет прямого перевода из list.index()
в функцию pytorch. Однако вы можете добиться аналогичных результатов, используя tensor==number
а затем nonzero()
функцию. Например:
t = torch.Tensor([1, 2, 3])
print ((t == 2).nonzero())
Этот кусок кода возвращается
1
[torch.LongTensor размера 1x1]
Ответ 2
Может быть сделано путем преобразования в numpy следующим образом
import torch
x = torch.range(1,4)
print(x)
===> tensor([ 1., 2., 3., 4.])
nx = x.numpy()
np.where(nx == 3)[0][0]
===> 2
Ответ 3
import torch
x_data = variable(torch.Tensor([[1.0], [2.0], [3.0]]))
print(x_data.data[0])
>>tensor([1.])
Ответ 4
Для тензоров с плавающей точкой я использую это, чтобы получить индекс элемента в тензоре.
print((torch.abs((torch.max(your_tensor).item()-your_tensor))<0.0001).nonzero())
Здесь я хочу получить индекс max_value в тензоре с плавающей запятой, вы также можете указать свое значение таким образом, чтобы получить индекс любых элементов в тензоре.
print((torch.abs((YOUR_VALUE-your_tensor))<0.0001).nonzero())