Как инвертировать массив перестановок в numpy
Учитывая самоиндексирующий (не уверенный, является ли это правильный термин) массив numpy, например:
a = np.array([3, 2, 0, 1])
Это означает, что permutation (=>
- стрелка):
0 => 3
1 => 2
2 => 0
3 => 1
Я пытаюсь сделать массив, представляющий обратное преобразование, не делая его "вручную" в python, то есть я хочу чистое решение numpy. Результат, который я хочу в приведенном выше случае:
array([2, 3, 1, 0])
Что эквивалентно
0 <= 3 0 => 2
1 <= 2 or 1 => 3
2 <= 0 2 => 1
3 <= 1 3 => 0
Кажется, это так просто, но я просто не могу придумать, как это сделать. Я пробовал поиск в Интернете, но не нашел ничего подходящего.
Ответы
Ответ 1
Обратным к перестановке p
of np.arange(n)
является массив индексов s
, сортирующий p
, т.е.
p[s] == np.arange(n)
должно быть истинным. Такой s
- это именно то, что возвращает np.argsort
:
>>> p = np.array([3, 2, 0, 1])
>>> np.argsort(p)
array([2, 3, 1, 0])
>>> p[np.argsort(p)]
array([0, 1, 2, 3])
Ответ 2
Сортировка здесь является излишним. Это всего лишь однопроходный, линейный алгоритм времени с постоянной потребностью в памяти:
from __future__ import print_function
import numpy as np
p = np.array([3, 2, 0, 1])
s = np.empty(p.size, dtype=np.int32)
for i in np.arange(p.size):
s[p[i]] = i
print( =', s)
Вышеприведенный код печатает
s = [2 3 1 0]
по мере необходимости.
Остальная часть ответа связана с эффективной векторизацией вышеуказанного цикла for
. Если вы просто хотите узнать решение, перейдите к концу этого ответа.
(Исходный ответ от 27 августа 2014 г., тайминги действительны для NumPy 1.8. Обновление с NumPy 1.11 следует позже).
Ожидается, что однопроходный, линейный алгоритм времени будет быстрее, чем np.argsort
; Интересно, что тривиальная векторизация (s[p] = xrange(p.size)
, см. массивы индексов) выше цикла for
на самом деле немного медленнее, чем np.argsort
, как долго как p.size < 700 000
(ну, на моей машине, ваш пробег будет меняться):
import numpy as np
def np_argsort(p):
return np.argsort(p)
def np_fancy(p):
s = np.zeros(p.size, p.dtype) # np.zeros is better than np.empty here, at least on Linux
s[p] = xrange(p.size)
return s
def create_input(n):
np.random.seed(31)
indices = np.arange(n, dtype = np.int32)
return np.random.permutation(indices)
Из моего ноутбука IPython:
p = create_input(700000)
%timeit np_argsort(p)
10 loops, best of 3: 72.7 ms per loop
%timeit np_fancy(p)
10 loops, best of 3: 70.2 ms per loop
В конечном итоге асимптотическая сложность срабатывает (O(n log n)
для argsort
vs. O(n)
для однопроходного алгоритма), а однопроходный алгоритм будет последовательно быстрее после достаточно большого n = p.size
(порог вокруг 700k на моей машине).
Однако существует менее простой способ векторизации вышеуказанного цикла for
с np.put
:
def np_put(p):
n = p.size
s = np.zeros(n, dtype = np.int32)
i = np.arange(n, dtype = np.int32)
np.put(s, p, i) # s[p[i]] = i
return s
Что дает для n = 700 000
(такого же размера, как указано выше):
p = create_input(700000)
%timeit np_put(p)
100 loops, best of 3: 12.8 ms per loop
Это хорошая скорость 5.6x для почти ничего!
Чтобы быть справедливым, np.argsort
по-прежнему превосходит подход np.put
для меньшего n
(точка опрокидывания вокруг n = 1210
на моей машине):
p = create_input(1210)
%timeit np_argsort(p)
10000 loops, best of 3: 25.1 µs per loop
%timeit np_fancy(p)
10000 loops, best of 3: 118 µs per loop
%timeit np_put(p)
10000 loops, best of 3: 25 µs per loop
Это, скорее всего, потому, что мы выделяем и заполняем дополнительный массив (при вызове np.arange()
) с помощью подхода np_put
.
Хотя вы не просили решения Cython, просто из любопытства, я также приурочил следующее решение Cython с типизированными просмотрами памяти:
import numpy as np
cimport numpy as np
def in_cython(np.ndarray[np.int32_t] p):
cdef int i
cdef int[:] pmv
cdef int[:] smv
pmv = p
s = np.empty(p.size, dtype=np.int32)
smv = s
for i in xrange(p.size):
smv[pmv[i]] = i
return s
Тайминги:
p = create_input(700000)
%timeit in_cython(p)
100 loops, best of 3: 2.59 ms per loop
Итак, решение np.put
все еще не так быстро (за 12,8 мс для этого размера ввода, а argsort - 72,7 мс).
Обновление 3 февраля 2017 года с помощью NumPy 1.11
Джейми, Андрис и Пол отметили в комментариях ниже, что проблема производительности с фантастическим индексированием была решена. Джейми говорит, что он уже разрешен в NumPy 1.9. Я тестировал его с помощью Python 3.5 и NumPy 1.11 на машине, которую я использовал в 2014 году.
def invert_permutation(p):
s = np.empty(p.size, p.dtype)
s[p] = np.arange(p.size)
return s
Тайминги:
p = create_input(880)
%timeit np_argsort(p)
100000 loops, best of 3: 11.6 µs per loop
%timeit invert_permutation(p)
100000 loops, best of 3: 11.5 µs per loop
Значительное улучшение действительно!
Заключение
В общем, я бы пошел с
def invert_permutation(p):
'''The argument p is assumed to be some permutation of 0, 1, ..., len(p)-1.
Returns an array s, where s[i] gives the index of i in p.
'''
s = np.empty(p.size, p.dtype)
s[p] = np.arange(p.size)
return s
подход для четкости кода. На мой взгляд, он менее неясен, чем argsort
, а также быстрее для больших размеров ввода. Если скорость становится проблемой, я бы пошел с решением Cython.
Ответ 3
Я хотел бы предложить немного больше фона для правильного ответа larsmans. Причина, по которой argsort
верна, может быть найдена при использовании представления перестановки по матрице. Математическим преимуществом матрицы перестановок P
является то, что матрица "работает на векторах", т.е. Матрица перестановок, когда вектор переставляет вектор.
Ваша перестановка выглядит так:
import numpy as np
a = np.array([3,2,0,1])
N = a.size
rows = np.arange(N)
P = np.zeros((N,N),dtype=int)
P[rows,a] = 1
[[0 0 0 1]
[0 0 1 0]
[1 0 0 0]
[0 1 0 0]]
Учитывая матрицу перестановок, мы можем "отменить" умножение путем умножения на нее инверсного P^-1
. Красота матриц перестановок состоит в том, что они ортогональны, поэтому P*P^(-1)=I
, или, другими словами, P(-1)=P^T
, обратная является транспонированной. Это означает, что мы можем взять индексы матрицы транспонирования, чтобы найти ваш инвертированный вектор перестановки:
inv_a = np.where(P.T)[1]
[2 3 1 0]
Что, если вы думаете об этом, точно так же, как поиск индексов, сортирующих столбцы P
!