Как заменить только первые n элементов в массиве numpy, которые больше определенного значения?
У меня есть массив myA
следующим образом:
array([ 7, 4, 5, 8, 3, 10])
Если я хочу заменить все значения, превышающие значение val
на 0, я могу просто сделать:
myA[myA > val] = 0
который дает мне желаемый результат (для val = 5
):
array([0, 4, 5, 0, 3, 0])
Однако моя цель - заменить не все, а только первые n
элементы этого массива, которые больше значения val
.
Итак, если n = 2
мой желаемый результат будет выглядеть следующим образом (10
является третьим элементом и поэтому не должен быть заменен):
array([ 0, 4, 5, 0, 3, 10])
Прямая реализация:
import numpy as np
myA = np.array([7, 4, 5, 8, 3, 10])
n = 2
val = 5
# track the number of replacements
repl = 0
for ind, vali in enumerate(myA):
if vali > val:
myA[ind] = 0
repl += 1
if repl == n:
break
Это работает, но, может быть, кто-то может справиться с умным способом маскировки!?
Ответы
Ответ 1
Следующее должно работать:
myA[(myA > val).nonzero()[0][:2]] = 0
поскольку nonzero вернет индексы, в которых булевский массив myA > val
не равен нулю, например. True
.
Например:
In [1]: myA = array([ 7, 4, 5, 8, 3, 10])
In [2]: myA[(myA > 5).nonzero()[0][:2]] = 0
In [3]: myA
Out[3]: array([ 0, 4, 5, 0, 3, 10])
Ответ 2
Окончательное решение очень просто:
import numpy as np
myA = np.array([7, 4, 5, 8, 3, 10])
n = 2
val = 5
myA[np.where(myA > val)[0][:n]] = 0
print(myA)
Вывод:
[ 0 4 5 0 3 10]
Ответ 3
Здесь другая возможность (непроверенная), вероятно, не лучше nonzero
:
def truncate_mask(m, stop):
m = m.astype(bool, copy=False) # if we allow non-bool m, the next line becomes nonsense
return m & (np.cumsum(m) <= stop)
myA[truncate_mask(myA > val, n)] = 0
Избегая создания и использования явного индекса, вы можете получить чуть более высокую производительность... но вам придется проверить его, чтобы узнать.
Отредактируйте 1:, пока мы находимся на предмет возможностей, вы также можете попробовать:
def truncate_mask(m, stop):
m = m.astype(bool, copy=True) # note we need to copy m here to safely modify it
m[np.searchsorted(np.cumsum(m), stop):] = 0
return m
Изменить 2 (на следующий день): Я только что проверил это, и кажется, что cumsum
на самом деле хуже, чем nonzero
, по крайней мере, с типы значений Я использовал (поэтому ни один из вышеперечисленных подходов не стоит использовать). Из любопытства я также попробовал его с numba:
import numba
@numba.jit
def set_first_n_gt_thresh(a, val, thresh, n):
ii = 0
while n>0 and ii < len(a):
if a[ii] > thresh:
a[ii] = val
n -= 1
ii += 1
Это только итерация по массиву один раз, или, скорее, она только итерации над необходимой частью массива один раз, даже не касаясь последней части. Это дает вам превосходную производительность для небольших n
, но даже в худшем случае n>=len(a)
этот подход выполняется быстрее.
Ответ 4
Вы можете использовать то же решение, что и здесь, преобразовывая вас np.array
в pd.Series
:
s = pd.Series([ 7, 4, 5, 8, 3, 10])
n = 2
m = 5
s[s[s>m].iloc[:n].index] = 0
In [416]: s
Out[416]:
0 0
1 4
2 5
3 0
4 3
5 10
dtype: int64
Пошаговое объяснение:
In [426]: s > m
Out[426]:
0 True
1 False
2 False
3 True
4 False
5 True
dtype: bool
In [428]: s[s>m].iloc[:n]
Out[428]:
0 7
3 8
dtype: int64
In [429]: s[s>m].iloc[:n].index
Out[429]: Int64Index([0, 3], dtype='int64')
In [430]: s[s[s>m].iloc[:n].index]
Out[430]:
0 7
3 8
dtype: int64
Вывод в In[430]
выглядит так же, как In[428]
, но в 428 это копия и в 430 оригинальных сериях.
Если вам понадобится np.array
, вы можете использовать метод values
:
In [418]: s.values
Out[418]: array([ 0, 4, 5, 0, 3, 10], dtype=int64)