Фильтровать строки массива numpy?

Я хочу применить функцию к каждой строке массива numpy. Если эта функция будет равна true, я сохраню строку, иначе я ее отброшу. Например, моя функция может быть:

def f(row):
    if sum(row)>10: return True
    else: return False

Мне было интересно, было ли что-то похожее на:

np.apply_over_axes()

который применяет функцию к каждой строке массива numpy и возвращает результат. Я надеялся на что-то вроде:

np.filter_over_axes()

который применял бы функцию к каждой строке массива numpy и возвращал бы строки, для которых функция возвращала true. Есть ли что-нибудь подобное? Или я должен просто использовать цикл for?

Ответы

Ответ 1

В идеале вы сможете реализовать векторизованную версию своей функции и использовать ее для булевской индексации. Для подавляющего большинства проблем это правильное решение. Numpy предоставляет довольно много функций, которые могут действовать по различным осям, а также все основные операции и сравнения, поэтому большинство полезных условий должны быть векторизуемыми.

import numpy as np

x = np.random.randn(20, 3)
x_new = x[np.sum(x, axis=1) > .5]

Если вы абсолютно уверены, что не можете сделать выше, я бы предложил использовать понимание списка (или np.apply_along_axis), чтобы создать массив bools индексировать с помощью.

def myfunc(row):
    return sum(row) > .5

bool_arr = np.array([myfunc(row) for row in x])
x_new = x[bool_arr]

Это позволит выполнить работу относительно чистым способом, но будет значительно медленнее, чем векторная версия. Пример:

x = np.random.randn(5000, 200)

%timeit x[np.sum(x, axis=1) > .5]
# 100 loops, best of 3: 5.71 ms per loop

%timeit x[np.array([myfunc(row) for row in x])]
# 1 loops, best of 3: 217 ms per loop