Разъяснение по tf.Tensor.set_shape()

У меня есть изображение размером 478 x 717 x 3 = 1028178 пикселей с ранга 1. Я проверил его, вызвав tf.shape и tf.rank.

Когда я вызываю image.set_shape ([478, 717, 3]), он выдает следующую ошибку.

"Shapes %s and %s must have the same rank" % (self, other)) 
ValueError: Shapes (?,) and (478, 717, 3) must have the same rank

Я проверил снова, выполнив первое задание до 1028178, но ошибка все еще существует.

ValueError: Shapes (1028178,) and (478, 717, 3) must have the same rank

Ну, это имеет смысл, потому что один из рангов 1, а другой - ранга 3. Однако зачем нужно бросать ошибку, так как общее количество пикселей все еще соответствует.

Я мог бы, конечно, использовать tf.resape, и это работает, но я думаю, что это не оптимально.

Как указано в FAQ TensorFlow

В чем разница между x.set_shape() и x = tf.reshape(x)?

Метод tf.Tensor.set_shape() обновляет статическую форму объекта Tensor, и он обычно используется для предоставления дополнительной информации о форме, когда это невозможно определить непосредственно. Он не меняет динамическую форму тензора.

Операция tf.reshape() создает новый тензор с другой динамической формой.

Создание нового тензора связано с распределением памяти и потенциально может быть более дорогостоящим, когда задействованы дополнительные примеры обучения. Это по дизайну, или я что-то пропустил?

Ответы

Ответ 1

Насколько я знаю (и я написал этот код), в Tensor.set_shape() нет ошибки. Я думаю, что недоразумение проистекает из запутанного имени этого метода.

Чтобы подробно Tensor.set_shape() вами часто задаваемую информацию, Tensor.set_shape() - это функция pure-Python, которая улучшает информацию о форме для данного объекта tf.Tensor. "Улучшает", я имею в виду "делает более конкретным".

Поэтому, когда у вас есть Tensor объект t с формой (?,), Это одномерный тензор неизвестной длины. Вы можете вызвать t.set_shape((1028178,)), а затем t будет иметь форму (1028178,) когда вы вызываете t.get_shape(). Это не влияет на базовое хранилище или вообще что-либо на бэкэнд: это просто означает, что последующий вывод формы с использованием t может основываться на утверждении, что это вектор длины 1028178.

Если t имеет форму (?,), Вызов t.set_shape((478, 717, 3)) потерпит неудачу, потому что TensorFlow уже знает, что t является вектором, поэтому он не может иметь форму (478, 717, 3). Если вы хотите создать новый тензор с этой формой из содержимого t, вы можете использовать reshaped_t = tf.reshape(t, (478, 717, 3)). Это создает новый объект tf.Tensor в Python; фактическая реализация tf.reshape() делает это с использованием мелкой копии тензорного буфера, поэтому на практике это недорого.

Одна из аналогий заключается в том, что Tensor.set_shape() похож на время выполнения на объектно-ориентированном языке, таком как Java. Например, если у вас есть указатель на Object но знаете, что на самом деле это String, вы можете сделать (String) obj cast (String) obj, чтобы передать obj методу, который ожидает аргумент String. Однако, если у вас есть String s и попытайтесь передать его java.util.Vector, компилятор даст вам ошибку, потому что эти два типа не связаны.