Pytorch softmax: Какое измерение использовать?
Функция torch.nn.functional.softmax
принимает два параметра: input
и dim
. Согласно своей документации операция softmax применяется ко всем срезам input
вдоль заданного значения dim
и будет масштабировать их так, чтобы элементы находились в диапазоне (0, 1)
и суммировались с 1.
Пусть вход будет:
input = torch.randn((3, 4, 5, 6))
Предположим, я хочу следующее, так что каждая запись в этом массиве равна 1:
sum = torch.sum(input, dim = 3) # sum size is (3, 4, 5, 1)
Как следует применять softmax?
softmax(input, dim = 0) # Way Number 0
softmax(input, dim = 1) # Way Number 1
softmax(input, dim = 2) # Way Number 2
softmax(input, dim = 3) # Way Number 3
Моя интуиция говорит мне, что это последняя, но я не уверен. Английский не является моим родным языком, и использование слова along
, казалось запутанным мне из - за этого.
Я не совсем понимаю, что означает "вдоль", поэтому я буду использовать пример, который мог бы прояснить ситуацию. Предположим, что у нас есть тензор размера (s1, s2, s3, s4), и я хочу, чтобы это произошло
Ответы
Ответ 1
Самый простой способ заставить вас понять: скажем, вам дается тензор формы (s1, s2, s3, s4)
и, как вы упомянули, вы хотите, чтобы сумма всех записей по последней оси была равна 1.
sum = torch.sum(input, dim = 3) # input is of shape (s1, s2, s3, s4)
Затем вы должны вызвать softmax как:
softmax(input, dim = 3)
Чтобы легко понять, вы можете рассматривать 4d-тензор формы (s1, s2, s3, s4)
как тензор 2d или матрицу формы (s1*s2*s3, s4)
. Теперь, если вы хотите, чтобы матрица содержала значения в каждой строке (ось = 0) или столбец (ось = 1), которые суммируются до 1, вы можете просто вызвать функцию softmax
на тензоре 2d следующим образом:
softmax(input, dim = 0) # normalizes values along axis 0
softmax(input, dim = 1) # normalizes values along axis 1
Вы можете увидеть пример, который Стивен упомянул в своем ответе.
Ответ 2
Стивен ответ выше не является правильным. Смотрите снимок ниже. Это на самом деле обратный путь.
Ответ 3
Я не уверен на 100%, что означает ваш вопрос, но я думаю, что ваша путаница заключается просто в том, что вы не понимаете, что означает параметр dim
. Поэтому я объясню это и приведу примеры.
Если у нас есть:
m0 = nn.Softmax(dim=0)
это означает, что m0
нормализует элементы по нулевой координате полученного тензора. Формально, если задан тензор b
размера скажем (d0,d1)
тогда будет верно следующее:
sum^{d0}_{i0=1} b[i0,i1] = 1, forall i1 \in {0,...,d1}
Вы можете легко проверить это на примере Pytorch:
>>> b = torch.arange(0,4,1.0).view(-1,2)
>>> b
tensor([[0., 1.],
[2., 3.]])
>>> m0 = nn.Softmax(dim=0)
>>> b0 = m0(b)
>>> b0
tensor([[0.1192, 0.1192],
[0.8808, 0.8808]])
теперь, поскольку dim=0
означает прохождение i0 \in {0,1}
(т.е. прохождение по строкам), если мы выберем любой столбец i1
и суммируем его элементы (т.е. строки), то мы должны получить 1. Проверьте это:
>>> b0[:,0].sum()
tensor(1.0000)
>>> b0[:,1].sum()
tensor(1.0000)
как и ожидалось.
Обратите внимание, что мы получаем все строки с суммой 1, "суммируя строки" с помощью torch.sum(b0,dim=0)
, проверьте это:
>>> torch.sum(b0,0)
tensor([1.0000, 1.0000])
Мы можем создать более сложный пример, чтобы убедиться, что он действительно понятен.
a = torch.arange(0,24,1.0).view(-1,3,4)
>>> a
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],
[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]])
>>> a0 = m0(a)
>>> a0[:,0,0].sum()
tensor(1.0000)
>>> a0[:,1,0].sum()
tensor(1.0000)
>>> a0[:,2,0].sum()
tensor(1.0000)
>>> a0[:,1,0].sum()
tensor(1.0000)
>>> a0[:,1,1].sum()
tensor(1.0000)
>>> a0[:,2,3].sum()
tensor(1.0000)
как мы и ожидали, если мы сложим все элементы по первой координате от первого значения до последнего значения, которое мы получим 1. Таким образом, все нормализуется по первому измерению (или по первой координате i0
).
>>> torch.sum(a0,0)
tensor([[1.0000, 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000]])
Ответ 4
Рассмотрим пример в двух измерениях
x = [[1,2],
[3,4]]
вы хотите, чтобы ваш конечный результат был
y = [[0.27,0.73],
[0.27,0.73]]
или же
y = [[0.12,0.12],
[0.88,0.88]]
Если это первый вариант, тогда вы хотите, чтобы dim = 1. Если это второй вариант, вы хотите, чтобы dim = 0.
Обратите внимание, что столбцы или нулевая размерность нормированы во втором примере, поэтому она нормирована вдоль нулевой размерности.
Обновлено 2018-07-10: чтобы отразить, что нулевое измерение относится к столбцам в pytorch.