Numpy объединить отсортированный массив в новый массив?
Есть ли способ сделать что-то вроде merge в mergesort, используя функцию numpy?
некоторая функция, например merge:
a = np.array([1,3,5])
b = np.array([2,4,6])
c = merge(a, b) # c == np.array([1,2,3,4,5,6])
Мне хотелось бы получить высокую производительность для больших данных благодаря numpy
Ответы
Ответ 1
Вы можете использовать
c = concatenate((a,b))
c.sort(kind='mergesort')
Я боюсь, что вы не сможете сделать это лучше, если только вы не напишите свою собственную функцию сортировки как расширение python, а cython
.
См. этот вопрос для аналогичной проблемы, но сохраняя только уникальные значения в объединенном массиве. Оценки и комментарии там также проницательны.
Ответ 2
Когда один массив значительно больше другого, приличное ускорение (в 5 раз на моем компьютере) можно получить, выполнив команду np.searchorted, скорость которой ограничена, главным образом, путем поиска индексов вставки меньшего массива:
import numpy as np
def classic_merge(a, b):
c = np.concatenate((a,b))
c.sort(kind='mergesort')
return c
def new_merge(a, b):
if len(a) < len(b):
b, a = a, b
c = np.empty(len(a) + len(b), dtype=a.dtype)
b_indices = np.arange(len(b)) + np.searchsorted(a, b)
a_indices = np.ones(len(c), dtype=bool)
a_indices[b_indices] = False
c[b_indices] = b
c[a_indices] = a
return c
Сроки дает:
from timeit import timeit as t
results = []
for size_digits in range(2, 8):
size = 10**size_digits
# size difference of a factor 10 here makes the difference!
a = np.arange(size // 10, dtype=np.int)
b = np.arange(size, dtype=np.int)
classic = t(lambda: classic_merge(a, b), number=10)
new = t(lambda: new_merge(a, b), number=10)
results.append((size_digits, classic, new))
if True:
text_format = " ".join(["{:<15}"] * 3)
print(text_format.format("log10(size)", "Classic", "New"))
table_format = " ".join(["{:.5f}"] * 3)
for result in results:
print(table_format.format(*result))
log10(size) Classic New
2.00000 0.00009 0.00027
3.00000 0.00021 0.00030
4.00000 0.00233 0.00082
5.00000 0.02827 0.00601
6.00000 0.33322 0.06059
7.00000 4.40571 0.86764
Когда a и b примерно одинаковы, различия в длине меньше:
from timeit import timeit as t
results = []
for size_digits in range(2, 8):
size = 10**size_digits
# same size
a = np.arange(size , dtype=np.int)
b = np.arange(size, dtype=np.int)
classic = t(lambda: classic_merge(a, b), number=10)
new = t(lambda: new_merge(a, b), number=10)
results.append((size_digits, classic, new))
if True:
text_format = " ".join(["{:<15}"] * 3)
print(text_format.format("log10(size)", "Classic", "New"))
table_format = " ".join(["{:.5f}"] * 3)
for result in results:
print(table_format.format(*result))
log10(size) Classic New
2.00000 0.00026 0.00087
3.00000 0.00108 0.00182
4.00000 0.01257 0.01243
5.00000 0.16333 0.12692
6.00000 1.05006 0.49186
7.00000 8.35967 5.93732