Найти позицию максимум для уникального бункера (binargmax)
Настроить
Предположим, что у меня есть
bins = np.array([0, 0, 1, 1, 2, 2, 2, 0, 1, 2])
vals = np.array([8, 7, 3, 4, 1, 2, 6, 5, 0, 9])
k = 3
Мне нужна позиция максимальных значений уникальным бункером в bins
.
# Bin == 0
# ↓ ↓ ↓
# [0 0 1 1 2 2 2 0 1 2]
# [8 7 3 4 1 2 6 5 0 9]
# ↑ ↑ ↑
# ⇧
# [0 1 2 3 4 5 6 7 8 9]
# Maximum is 8 and happens at position 0
(vals * (bins == 0)).argmax()
0
# Bin == 1
# ↓ ↓ ↓
# [0 0 1 1 2 2 2 0 1 2]
# [8 7 3 4 1 2 6 5 0 9]
# ↑ ↑ ↑
# ⇧
# [0 1 2 3 4 5 6 7 8 9]
# Maximum is 4 and happens at position 3
(vals * (bins == 1)).argmax()
3
# Bin == 2
# ↓ ↓ ↓ ↓
# [0 0 1 1 2 2 2 0 1 2]
# [8 7 3 4 1 2 6 5 0 9]
# ↑ ↑ ↑ ↑
# ⇧
# [0 1 2 3 4 5 6 7 8 9]
# Maximum is 9 and happens at position 9
(vals * (bins == 2)).argmax()
9
Эти функции хакерские и даже не обобщаемы для отрицательных значений.
Вопрос
Как получить все такие значения наиболее эффективным образом с помощью Numpy?
То, что я пробовал.
def binargmax(bins, vals, k):
out = -np.ones(k, np.int64)
trk = np.empty(k, vals.dtype)
trk.fill(np.nanmin(vals) - 1)
for i in range(len(bins)):
v = vals[i]
b = bins[i]
if v > trk[b]:
trk[b] = v
out[b] = i
return out
binargmax(bins, vals, k)
array([0, 3, 9])
СВЯЗЬ С ИСПЫТАНИЯМИ И ВАЛИДАЦИЕЙ
Ответы
Ответ 1
Здесь один из способов, компенсируя данные каждой группы, чтобы мы могли использовать argsort
для всех данных за один раз -
def binargmax_scale_sort(bins, vals):
w = np.bincount(bins)
valid_mask = w!=0
last_idx = w[valid_mask].cumsum()-1
scaled_vals = bins*(vals.max()+1) + vals
#unique_bins = np.flatnonzero(valid_mask) # if needed
return len(bins) -1 -np.argsort(scaled_vals[::-1], kind='mergesort')[last_idx]
Ответ 2
Библиотека numpy_indexed
:
Я знаю, что это не технически numpy
, но библиотека numpy_indexed
имеет векторную функцию group_by
которая идеально подходит для этого, просто хотела поделиться как альтернатива, которую я часто использую:
>>> import numpy_indexed as npi
>>> npi.group_by(bins).argmax(vals)
(array([0, 1, 2]), array([0, 3, 9], dtype=int64))
Используя простые pandas
groupby
и idxmax
:
df = pd.DataFrame({'bins': bins, 'vals': vals})
df.groupby('bins').vals.idxmax()
Использование sparse.csr_matrix
Эта опция очень быстро работает на очень больших входах.
sparse.csr_matrix(
(vals, bins, np.arange(vals.shape[0]+1)), (vals.shape[0], k)
).argmax(0)
# matrix([[0, 3, 9]])
Спектакль
функции
def chris(bins, vals, k):
return npi.group_by(bins).argmax(vals)
def chris2(df):
return df.groupby('bins').vals.idxmax()
def chris3(bins, vals, k):
sparse.csr_matrix((vals, bins, np.arange(vals.shape[0] + 1)), (vals.shape[0], k)).argmax(0)
def divakar(bins, vals, k):
mx = vals.max()+1
sidx = bins.argsort()
sb = bins[sidx]
sm = np.r_[sb[:-1] != sb[1:],True]
argmax_out = np.argsort(bins*mx + vals)[sm]
max_out = vals[argmax_out]
return max_out, argmax_out
def divakar2(bins, vals, k):
last_idx = np.bincount(bins).cumsum()-1
scaled_vals = bins*(vals.max()+1) + vals
argmax_out = np.argsort(scaled_vals)[last_idx]
max_out = vals[argmax_out]
return max_out, argmax_out
def user545424(bins, vals, k):
return np.argmax(vals*(bins == np.arange(bins.max()+1)[:,np.newaxis]),axis=-1)
def user2699(bins, vals, k):
res = []
for v in np.unique(bins):
idx = (bins==v)
r = np.where(idx)[0][np.argmax(vals[idx])]
res.append(r)
return np.array(res)
def sacul(bins, vals, k):
return np.lexsort((vals, bins))[np.append(np.diff(np.sort(bins)), 1).astype(bool)]
@njit
def piRSquared(bins, vals, k):
out = -np.ones(k, np.int64)
trk = np.empty(k, vals.dtype)
trk.fill(np.nanmin(vals))
for i in range(len(bins)):
v = vals[i]
b = bins[i]
if v > trk[b]:
trk[b] = v
out[b] = i
return out
Настроить
import numpy_indexed as npi
import numpy as np
import pandas as pd
from timeit import timeit
import matplotlib.pyplot as plt
from numba import njit
from scipy import sparse
res = pd.DataFrame(
index=['chris', 'chris2', 'chris3', 'divakar', 'divakar2', 'user545424', 'user2699', 'sacul', 'piRSquared'],
columns=[10, 50, 100, 500, 1000, 5000, 10000, 50000, 100000, 500000],
dtype=float
)
k = 5
for f in res.index:
for c in res.columns:
bins = np.random.randint(0, k, c)
k = 5
vals = np.random.rand(c)
df = pd.DataFrame({'bins': bins, 'vals': vals})
stmt = '{}(df)'.format(f) if f in {'chris2'} else '{}(bins, vals, k)'.format(f)
setp = 'from __main__ import bins, vals, k, df, {}'.format(f)
res.at[f, c] = timeit(stmt, setp, number=50)
ax = res.div(res.min()).T.plot(loglog=True)
ax.set_xlabel("N");
ax.set_ylabel("time (relative)");
plt.show()
Результаты
Результаты с гораздо большим k
(Это то, где широковещательная передача сильно ударяется):
res = pd.DataFrame(
index=['chris', 'chris2', 'chris3', 'divakar', 'divakar2', 'user545424', 'user2699', 'sacul', 'piRSquared'],
columns=[10, 50, 100, 500, 1000, 5000, 10000, 50000, 100000, 500000],
dtype=float
)
k = 500
for f in res.index:
for c in res.columns:
bins = np.random.randint(0, k, c)
vals = np.random.rand(c)
df = pd.DataFrame({'bins': bins, 'vals': vals})
stmt = '{}(df)'.format(f) if f in {'chris2'} else '{}(bins, vals, k)'.format(f)
setp = 'from __main__ import bins, vals, df, k, {}'.format(f)
res.at[f, c] = timeit(stmt, setp, number=50)
ax = res.div(res.min()).T.plot(loglog=True)
ax.set_xlabel("N");
ax.set_ylabel("time (relative)");
plt.show()
Как видно из графиков, широковещательная передача является отличным трюком, когда количество групп невелико, однако временная сложность/память вещания слишком быстро возрастает при более высоких значениях k
чтобы сделать ее высокоэффективной.
Ответ 3
Хорошо, здесь моя линейная запись времени, используя только индексирование и np.(max|min)inum.at
Он предполагает, что бункеры поднимаются от 0 до макс (бункеров).
def via_at(bins, vals):
max_vals = np.full(bins.max()+1, -np.inf)
np.maximum.at(max_vals, bins, vals)
expanded = max_vals[bins]
max_idx = np.full_like(max_vals, np.inf)
np.minimum.at(max_idx, bins, np.where(vals == expanded, np.arange(len(bins)), np.inf))
return max_vals, max_idx
Ответ 4
Как насчет этого:
>>> import numpy as np
>>> bins = np.array([0, 0, 1, 1, 2, 2, 2, 0, 1, 2])
>>> vals = np.array([8, 7, 3, 4, 1, 2, 6, 5, 0, 9])
>>> k = 3
>>> np.argmax(vals*(bins == np.arange(k)[:,np.newaxis]),axis=-1)
array([0, 3, 9])
Ответ 5
Если вы собираетесь читать, это может быть не лучшее решение, но я думаю, что это работает
def binargsort(bins,vals):
s = np.lexsort((vals,bins))
s2 = np.sort(bins)
msk = np.roll(s2,-1) != s2
# or use this for msk, but not noticeably better for performance:
# msk = np.append(np.diff(np.sort(bins)),1).astype(bool)
return s[msk]
array([0, 3, 9])
Объяснение:
lexsort
сортирует индексы vals
согласно отсортированному порядку bins
, затем по порядку vals
:
>>> np.lexsort((vals,bins))
array([7, 1, 0, 8, 2, 3, 4, 5, 6, 9])
Таким образом, вы можете замаскировать, чтобы отсортированные bins
отличались от одного индекса к другому:
>>> np.sort(bins)
array([0, 0, 0, 1, 1, 1, 2, 2, 2, 2])
# Find where sorted bins end, use that as your mask on the 'lexsort'
>>> np.append(np.diff(np.sort(bins)),1)
array([0, 0, 1, 0, 0, 1, 0, 0, 0, 1])
>>> np.lexsort((vals,bins))[np.append(np.diff(np.sort(bins)),1).astype(bool)]
array([0, 3, 9])
Ответ 6
Это интересная небольшая проблема для решения. Мой подход заключается в том, чтобы получить индекс в vals
на основе значений в bins
. Используя, where
для получения точек, где индекс равен True
в сочетании с argmax
эти точки в vals дают полученное значение.
def binargmaxA(bins, vals):
res = []
for v in unique(bins):
idx = (bins==v)
r = where(idx)[0][argmax(vals[idx])]
res.append(r)
return array(res)
Можно удалить вызов до unique
, используя range(k)
чтобы получить возможные значения бинов. Это ускоряет работу, но все равно оставляет ее с низкой производительностью по мере увеличения размера k.
def binargmaxA2(bins, vals, k):
res = []
for v in range(k):
idx = (bins==v)
r = where(idx)[0][argmax(vals[idx])]
res.append(r)
return array(res)
Последняя попытка, сравнивая каждое значение, существенно замедляет работу. Эта версия вычисляет отсортированный массив значений, а не делает сравнение для каждого уникального значения. Ну, он на самом деле вычисляет отсортированные индексы и получает только отсортированные значения, когда это необходимо, поскольку это позволяет избежать однократной загрузки vals в память. Производительность по-прежнему масштабируется с количеством ящиков, но намного медленнее, чем раньше.
def binargmaxB(bins, vals):
idx = argsort(bins) # Find sorted indices
split = r_[0, where(diff(bins[idx]))[0]+1, len(bins)] # Compute where values start in sorted array
newmax = [argmax(vals[idx[i1:i2]]) for i1, i2 in zip(split, split[1:])] # Find max for each value in sorted array
return idx[newmax +split[:-1]] # Convert to indices in unsorted array
Ориентиры
Вот некоторые тесты с другими ответами.
3000 элементов
С несколько большим набором данных (bins = randint(0, 30, 3000); vals = randn(3000)
; k = 30;)
- 171us binargmax_scale_sort2 от Divakar
- Ответ на этот вопрос, версия B
- 281us binargmax_scale_sort от Divakar
- 329us широковещательная версия от пользователя545424
- Ответ на этот вопрос, версия A
- 416us ответ сакул, используя lexsort
- Код ссылки 899us by piRsquared
30000 элементов
И еще больший набор данных (bins = randint(0, 30, 30000); vals = randn(30000)
, k = 30). Удивительно, но это не меняет относительной производительности между решениями.
- 1.27 мс этот ответ, версия B
- 2.01ms binargmax_scale_sort2 от Divakar
- 2.38ms широковещательная версия от пользователя545424
- 2.68ms этот ответ, версия A
- 5.71 мс ответ by sacul, используя lexsort
- Код ссылки 9.12ms с помощью piRSquared
Редактировать Я не изменил k
с увеличением числа возможных значений бинов, теперь, когда я исправил, что эталонные тесты более четкие.
1000 значений бункера
Увеличение числа уникальных значений бинов также может повлиять на производительность. Решения Дивакара и Сакула в основном не затрагиваются, в то время как другие оказывают существенное влияние. bins = randint(0, 1000, 30000); vals = randn(30000); k = 1000
- 1.99ms binargmax_scale_sort2 от Divakar
- 3.48 м. Этот ответ, версия B
- 6.15 мс ответ by sacul, используя lexsort
- Код ссылки 10.6ms с помощью piRsquared
- 27.2 м. Этот ответ, версия А
- Версия широковещательной передачи 129ms by user545424
Редактировать Включая контрольные показатели для кода ссылки в вопросе, это удивительно конкурентоспособно, особенно с большим количеством бункеров.
Ответ 7
Я знаю, что вы сказали использовать Numpy, но если Pandas приемлема:
import numpy as np; import pandas as pd;
(pd.DataFrame(
{'bins':np.array([0, 0, 1, 1, 2, 2, 2, 0, 1, 2]),
'values':np.array([8, 7, 3, 4, 1, 2, 6, 5, 0, 9])})
.groupby('bins')
.idxmax())
values
bins
0 0
1 3
2 9