Билинейный тензорный продукт в TensorFlow
Я работаю над повторной реализацией этой статьи, а ключевая операция - билинейный тензорный продукт. Я почти не знаю, что это значит, но у бумаги есть небольшая графика, которую я понимаю.
![введите описание изображения здесь]()
Операция ключа e_1 * W * e_2, и я хочу знать, как ее реализовать в тензорном потоке, потому что остальное должно быть легко.
В принципе, учитывая 3D-тензор W, разрежьте его на матрицы, а для j-го среза (матрицы) умножьте его с каждой стороны на e_1 и e_2, что приводит к скаляру, который является j-й вхождением в результирующий вектор (выход этой операции).
Итак, я хочу выполнить произведение e_1, d-мерного вектора, W, тензора dxdxk и e_2, другого d-мерный вектор. Может ли этот продукт быть кратко выражен в TensorFlow, как сейчас, или мне нужно каким-то образом определить свой собственный?
ИЗОБРАЖЕНИЯ EARLIER
Почему не происходит умножение этих тензоров, и есть ли способ определить его более явно, чтобы он работал?
>>> import tensorflow as tf
>>> tf.InteractiveSession()
>>> a = tf.ones([3, 3, 3])
>>> a.eval()
array([[[ 1., 1., 1.],
[ 1., 1., 1.],
[ 1., 1., 1.]],
[[ 1., 1., 1.],
[ 1., 1., 1.],
[ 1., 1., 1.]],
[[ 1., 1., 1.],
[ 1., 1., 1.],
[ 1., 1., 1.]]], dtype=float32)
>>> b = tf.ones([3, 1, 1])
>>> b.eval()
array([[[ 1.]],
[[ 1.]],
[[ 1.]]], dtype=float32)
>>>
Сообщение об ошибке
ValueError: Shapes TensorShape([Dimension(3), Dimension(3), Dimension(3)]) and TensorShape([Dimension(None), Dimension(None)]) must have the same rank
НАСТОЯЩЕЕ
Оказывается, что умножение двух трехмерных тензоров не работает ни с tf.matmul
, поэтому, но tf.batch_matmul
. tf.batch_matmul
также будут делать 3D-тензоры и матрицы. Затем я попробовал 3D и вектор:
ValueError: Dimensions Dimension(3) and Dimension(1) are not compatible
Ответы
Ответ 1
Вы можете сделать это с помощью простой перестройки. Для первого из двух матричных умножений у вас есть k * d, длина d векторов для точечного произведения с.
Это должно быть близко:
temp = tf.matmul(E1,tf.reshape(Wddk,[d,d*k]))
result = tf.matmul(E2,tf.reshape(temp,[d,k]))