Почему NumPy не замыкается на несмежные массивы?

Рассмотрим следующий простой тест:

import numpy as np
from timeit import timeit

a = np.random.randint(0,2,1000000,bool)

Давайте найдем индекс первого True

timeit(lambda:a.argmax(), number=1000)
# 0.000451055821031332

Это достаточно быстро из-за короткого замыкания numpy.

Он также работает на смежных ломтиках,

timeit(lambda:a[1:-1].argmax(), number=1000)
# 0.0006490410305559635

Но, похоже, не на несмежных. Я был в основном заинтересован в поиске последнего True:

timeit(lambda:a[::-1].argmax(), number=1000)
# 0.3737605109345168

ОБНОВЛЕНИЕ: Мое предположение, что наблюдаемое замедление было вызвано не коротким замыканием, является неточным (спасибо @Victor Ruiz). Действительно, в наихудший сценарий массива all False

b=np.zeros_like(a)
timeit(lambda:b.argmax(), number=1000)
# 0.04321779008023441

мы все еще на порядок быстрее, чем в несмежных кейс. Я готов принять объяснение Виктора, что настоящий виновник выполняется копия (время принудительного копирования с помощью .copy() наводит на мысль). После этого уже не имеет значения, происходит короткое замыкание или нет.

Но другие размеры шагов! = 1 приводят к аналогичному поведению.

timeit(lambda:a[::2].argmax(), number=1000)
# 0.19192566303536296

Вопрос: почему numpy не закорачивает ОБНОВЛЕНИЕ без копирования в последних двух примерах?

И, что еще более важно: есть ли обходной путь, то есть какой-то способ заставить numpy выполнить короткое замыкание ОБНОВЛЕНИЕ, не копируя также на несмежные массивы?

Ответы

Ответ 1

Я заинтересовался решением этой проблемы. Итак, я пришел к следующему решению, которое позволяет избежать проблемы "a[::-1]" из-за внутренних копий ndarray np.argmax:

Я создал небольшую библиотеку, которая реализует функцию argmax, которая является оберткой для np.argmax, но она повышает производительность, когда входным аргументом является одномерный логический массив со значением шага, установленным на -1:

https://github.com/Vykstorm/numpy-bool-argmax-ext

В этих случаях она использует низкоуровневую процедуру C для поиска индекса k элемента с максимальным значением (True), начиная с конца и до начала массива a.
Затем вы можете вычислить argmax(a[::-1]) с помощью len(a)-k-1

Низкоуровневый метод не выполняет никаких внутренних копий ndarray, потому что он работает с массивом a, который уже является C-смежным и выровнен в памяти. Это также относится к короткому замыканию


ОБНОВЛЕНИЕ: Я расширил библиотеку, чтобы улучшить производительность argmax также при работе со значениями шага, отличными от -1 (с булевыми массивами 1D) с хорошими результатами: a[::2], a[::-3], e.t.c.

Попробуйте.

Ответ 2

Проблема связана с выравниванием памяти массива при использовании шагов. Либо a[1:-1], a[::-1] считаются выровненными в памяти, но a[::2] DonT:

a = np.random.randint(0,2,1000000,bool)

print(a[1:-1].flags.c_contiguous) # True
print(a[::-1].flags.c_contiguous) # False
print(a[::2].flags.c_contiguous) # False

Это объясняет, почему np.argmax работает медленно на a[::2] (из документации на ndarrays):

Несколько алгоритмов в NumPy работают с произвольно пошаговыми массивами. Однако некоторые алгоритмы требуют односегментных массивов. Когда в такие алгоритмы передается массив с нерегулярным шагом, копия создается автоматически.

np.argmax(a[::2]) делает копию массива. Поэтому, если вы делаете timeit(lambda: np.argmax(a[::2]), number=5000), вы синхронизируете 5000 копий массива a

Выполните это и сравните результаты этих двух вызовов синхронизации:

print(timeit(lambda: np.argmax(a[::2]), number=5000))

b = a[::2].copy()
print(timeit(lambda: np.argmax(b), number=5000))

ОБНОВЛЕНИЕ: Заглянув в исходный код на языке C numpy, я обнаружил подчеркнутую реализацию функции argmax, PyArray_ArgMax, которая в какой-то момент вызывает PyArray_ContiguousFromAny, чтобы убедиться, что заданный входной массив выровнен в памяти (в стиле C)

Затем, если d-тип массива - bool, он делегирует функции BOOL_argmax. Глядя на его код, кажется, что короткое замыкание всегда применяется.

Резюме

  • Чтобы избежать копирования с помощью np.argmax, убедитесь, что входной массив непрерывен в памяти
  • Короткое замыкание всегда применяется, когда тип данных является логическим.