Часто вызов функции std:: nth_element()

Я не нашел эту конкретную тему нигде...

Я вызываю алгоритм nth_element() около 400 000 раз в секунду для разных данных в std::vector из 23 целых чисел, более точных значений без знака.

Я хочу улучшить скорость вычислений, и этот конкретный вызов требует значительной части времени процессора. Теперь я заметил, что, как и в случае std:: sort(), функция nth_element видна в профилировщике даже с самым высоким уровнем оптимизации и режимом NDEBUG (компилятор Linux Clang), поэтому сравнение является встроенным, но не вызовом функции. Ну, больше preise: не nth_element(), но std:: __ introselect() виден.

Поскольку размер данных мал, я экспериментировал с использованием квадратичной функции сортировки PIKSORT, которая часто быстрее, чем вызов std:: sort, когда размер данных составляет менее 20 элементов, вероятно, потому, что функция будет встроенной.

template <class CONTAINER>
inline void piksort(CONTAINER& arr)  // indeed this is "insertion sort"
{
    typename CONTAINER::value_type a;

    const int n = (int)arr.size();
    for (int j = 1; j<n; ++j) {
        a = arr[j];
        int i = j;
        while (i > 0 && a < arr[i - 1]) {
            arr[i] = arr[i - 1];
            i--;
        }
        arr[i] = a;
    }
}

Однако это было медленнее, чем использование nth_element в этом случае.

Кроме того, использование статистического метода не подходит, Что-то быстрее, чем std:: nth_element

Наконец, поскольку значения находятся в диапазоне от 0 до 20000, метод гистограммы не подходит.

Мой вопрос: кто-нибудь знает простое решение? Я думаю, что я, вероятно, не единственный, кто очень часто вызывает std:: sort или nth_element.

Ответы

Ответ 1

Вы упомянули, что размер массива всегда известен как 23. Кроме того, используемый тип unsigned short. В этом случае вы можете попытаться использовать сортировочную сеть размером 23; поскольку ваш тип unsigned short, сортировка всего массива с помощью сортировочной сети может быть даже быстрее, чем частичная сортировка с помощью std::nth_element. Вот очень простая реализация С++ 14 сортировочной сети размером 23 с 118 единицами обмена-обменом, как описано Использование симметрии и эволюции Поиск для минимизации сортировочных сетей:

template<typename RandomIt, typename Compare = std::less<>>
void network_sort23(RandomIt first, Compare compare={})
{
    swap_if(first[1u], first[20u], compare);
    swap_if(first[2u], first[21u], compare);
    swap_if(first[5u], first[13u], compare);
    swap_if(first[9u], first[17u], compare);
    swap_if(first[0u], first[7u], compare);
    swap_if(first[15u], first[22u], compare);
    swap_if(first[4u], first[11u], compare);
    swap_if(first[6u], first[12u], compare);
    swap_if(first[10u], first[16u], compare);
    swap_if(first[8u], first[18u], compare);
    swap_if(first[14u], first[19u], compare);
    swap_if(first[3u], first[8u], compare);
    swap_if(first[4u], first[14u], compare);
    swap_if(first[11u], first[18u], compare);
    swap_if(first[2u], first[6u], compare);
    swap_if(first[16u], first[20u], compare);
    swap_if(first[0u], first[9u], compare);
    swap_if(first[13u], first[22u], compare);
    swap_if(first[5u], first[15u], compare);
    swap_if(first[7u], first[17u], compare);
    swap_if(first[1u], first[10u], compare);
    swap_if(first[12u], first[21u], compare);
    swap_if(first[8u], first[19u], compare);
    swap_if(first[17u], first[22u], compare);
    swap_if(first[0u], first[5u], compare);
    swap_if(first[20u], first[21u], compare);
    swap_if(first[1u], first[2u], compare);
    swap_if(first[18u], first[19u], compare);
    swap_if(first[3u], first[4u], compare);
    swap_if(first[21u], first[22u], compare);
    swap_if(first[0u], first[1u], compare);
    swap_if(first[19u], first[22u], compare);
    swap_if(first[0u], first[3u], compare);
    swap_if(first[12u], first[13u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[6u], first[15u], compare);
    swap_if(first[7u], first[16u], compare);
    swap_if(first[8u], first[11u], compare);
    swap_if(first[11u], first[14u], compare);
    swap_if(first[4u], first[11u], compare);
    swap_if(first[6u], first[8u], compare);
    swap_if(first[14u], first[16u], compare);
    swap_if(first[17u], first[20u], compare);
    swap_if(first[2u], first[5u], compare);
    swap_if(first[9u], first[12u], compare);
    swap_if(first[10u], first[13u], compare);
    swap_if(first[15u], first[18u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[4u], first[7u], compare);
    swap_if(first[20u], first[21u], compare);
    swap_if(first[1u], first[2u], compare);
    swap_if(first[7u], first[15u], compare);
    swap_if(first[3u], first[9u], compare);
    swap_if(first[13u], first[19u], compare);
    swap_if(first[16u], first[18u], compare);
    swap_if(first[8u], first[14u], compare);
    swap_if(first[4u], first[6u], compare);
    swap_if(first[18u], first[21u], compare);
    swap_if(first[1u], first[4u], compare);
    swap_if(first[19u], first[21u], compare);
    swap_if(first[1u], first[3u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[11u], first[13u], compare);
    swap_if(first[2u], first[6u], compare);
    swap_if(first[16u], first[20u], compare);
    swap_if(first[4u], first[9u], compare);
    swap_if(first[13u], first[18u], compare);
    swap_if(first[19u], first[20u], compare);
    swap_if(first[2u], first[3u], compare);
    swap_if(first[18u], first[20u], compare);
    swap_if(first[2u], first[4u], compare);
    swap_if(first[5u], first[17u], compare);
    swap_if(first[12u], first[14u], compare);
    swap_if(first[8u], first[12u], compare);
    swap_if(first[5u], first[7u], compare);
    swap_if(first[15u], first[17u], compare);
    swap_if(first[5u], first[8u], compare);
    swap_if(first[14u], first[17u], compare);
    swap_if(first[3u], first[5u], compare);
    swap_if(first[17u], first[19u], compare);
    swap_if(first[3u], first[4u], compare);
    swap_if(first[18u], first[19u], compare);
    swap_if(first[6u], first[10u], compare);
    swap_if(first[11u], first[16u], compare);
    swap_if(first[13u], first[16u], compare);
    swap_if(first[6u], first[9u], compare);
    swap_if(first[16u], first[17u], compare);
    swap_if(first[5u], first[6u], compare);
    swap_if(first[4u], first[5u], compare);
    swap_if(first[7u], first[9u], compare);
    swap_if(first[17u], first[18u], compare);
    swap_if(first[12u], first[15u], compare);
    swap_if(first[14u], first[15u], compare);
    swap_if(first[8u], first[12u], compare);
    swap_if(first[7u], first[8u], compare);
    swap_if(first[13u], first[15u], compare);
    swap_if(first[15u], first[17u], compare);
    swap_if(first[5u], first[7u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[10u], first[14u], compare);
    swap_if(first[6u], first[11u], compare);
    swap_if(first[14u], first[16u], compare);
    swap_if(first[15u], first[16u], compare);
    swap_if(first[6u], first[7u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[9u], first[12u], compare);
    swap_if(first[11u], first[13u], compare);
    swap_if(first[13u], first[14u], compare);
    swap_if(first[8u], first[9u], compare);
    swap_if(first[7u], first[8u], compare);
    swap_if(first[14u], first[15u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[8u], first[9u], compare);
    swap_if(first[12u], first[14u], compare);
    swap_if(first[11u], first[12u], compare);
    swap_if(first[12u], first[13u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[11u], first[12u], compare);
}

Функция утилиты swap_if сравнивает два параметра x и y с предикатом compare и заменяет их, если compare(y, x). В моем примере используется функция aa generic swap_if, но вы можете использовать оптимизированную версию, если знаете, что в любом случае вы будете сравнивать значения unsigned short с operator< (вам может не понадобиться такая функция, если ваш компилятор распознает и оптимизирует сравнение -exchange, но, к сожалению, не все компиляторы делают это - я использую g++ 5.2 с -O3, и мне все еще нужна следующая функция для производительности):

void swap_if(unsigned short& x, unsigned short& y)
{
    unsigned short dx = x;
    unsigned short dy = y;
    unsigned short tmp = x = std::min(dx, dy);
    y ^= dx ^ tmp;
}

Теперь, чтобы убедиться, что это действительно быстрее, я решил время std::nth_element, когда требуется частичное сортировать только первые 10 элементов и сортировать все 23 элемента с сетью сортировки (1000000 раз с различными перетасованными массивами). Вот что я получаю:

std::nth_element    1158ms
network_sort23      487ms

Тем не менее, мой компьютер работает некоторое время и немного медленнее, но разница в производительности является аккуратной. Я считаю, что эта разница останется прежней при перезагрузке компьютера. Я могу попробовать это позже и сообщить вам.

Относительно того, как эти моменты были сгенерированы, я использовал измененную версию этот тест из моего библиотека cpp-сортировки. Исходная сеть сортировки и функции swap_if также поступают оттуда, поэтому вы можете быть уверены, что они были протестированы несколько раз:)

EDIT: вот результаты, которые я сейчас перезапустил на своем компьютере. Версия network_sort23 по-прежнему в два раза быстрее, чем std::nth_element:

std::nth_element    369ms
network_sort23      154ms

EDIT²: если все, что вам нужно в медиане, вы можете тривиально удалить единицы обмена обмена, которые не нужны для вычисления конечного значения, которое будет на 11-й позиции. В полученной медианной сети определения размера 23, которая используется ниже, используется другая сортировочная сеть размера 23, чем предыдущая, и она дает несколько лучшие результаты:

swap_if(first[0u], first[1u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[4u], first[5u], compare);
swap_if(first[6u], first[7u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[10u], first[11u], compare);
swap_if(first[1u], first[3u], compare);
swap_if(first[5u], first[7u], compare);
swap_if(first[9u], first[11u], compare);
swap_if(first[0u], first[2u], compare);
swap_if(first[4u], first[6u], compare);
swap_if(first[8u], first[10u], compare);
swap_if(first[1u], first[2u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[9u], first[10u], compare);
swap_if(first[1u], first[5u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[5u], first[9u], compare);
swap_if(first[2u], first[6u], compare);
swap_if(first[1u], first[5u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[0u], first[4u], compare);
swap_if(first[7u], first[11u], compare);
swap_if(first[3u], first[7u], compare);
swap_if(first[4u], first[8u], compare);
swap_if(first[0u], first[4u], compare);
swap_if(first[7u], first[11u], compare);
swap_if(first[1u], first[4u], compare);
swap_if(first[7u], first[10u], compare);
swap_if(first[3u], first[8u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[2u], first[4u], compare);
swap_if(first[7u], first[9u], compare);
swap_if(first[3u], first[5u], compare);
swap_if(first[6u], first[8u], compare);
swap_if(first[3u], first[4u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[7u], first[8u], compare);
swap_if(first[12u], first[13u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[16u], first[17u], compare);
swap_if(first[18u], first[19u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[13u], first[15u], compare);
swap_if(first[17u], first[19u], compare);
swap_if(first[12u], first[14u], compare);
swap_if(first[16u], first[18u], compare);
swap_if(first[20u], first[22u], compare);
swap_if(first[13u], first[14u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[21u], first[22u], compare);
swap_if(first[13u], first[17u], compare);
swap_if(first[18u], first[22u], compare);
swap_if(first[17u], first[21u], compare);
swap_if(first[14u], first[18u], compare);
swap_if(first[13u], first[17u], compare);
swap_if(first[18u], first[22u], compare);
swap_if(first[12u], first[16u], compare);
swap_if(first[15u], first[19u], compare);
swap_if(first[16u], first[20u], compare);
swap_if(first[12u], first[16u], compare);
swap_if(first[13u], first[16u], compare);
swap_if(first[19u], first[22u], compare);
swap_if(first[15u], first[20u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[14u], first[16u], compare);
swap_if(first[19u], first[21u], compare);
swap_if(first[15u], first[17u], compare);
swap_if(first[18u], first[20u], compare);
swap_if(first[15u], first[16u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[19u], first[20u], compare);
swap_if(first[0u], first[12u], compare);
swap_if(first[2u], first[14u], compare);
swap_if(first[4u], first[16u], compare);
swap_if(first[6u], first[18u], compare);
swap_if(first[8u], first[20u], compare);
swap_if(first[10u], first[22u], compare);
swap_if(first[2u], first[12u], compare);
swap_if(first[10u], first[20u], compare);
swap_if(first[4u], first[12u], compare);
swap_if(first[6u], first[14u], compare);
swap_if(first[8u], first[16u], compare);
swap_if(first[10u], first[18u], compare);
swap_if(first[8u], first[12u], compare);
swap_if(first[10u], first[14u], compare);
swap_if(first[10u], first[12u], compare);
swap_if(first[1u], first[13u], compare);
swap_if(first[3u], first[15u], compare);
swap_if(first[5u], first[17u], compare);
swap_if(first[7u], first[19u], compare);
swap_if(first[9u], first[21u], compare);
swap_if(first[3u], first[13u], compare);
swap_if(first[11u], first[21u], compare);
swap_if(first[5u], first[13u], compare);
swap_if(first[7u], first[15u], compare);
swap_if(first[9u], first[17u], compare);
swap_if(first[11u], first[19u], compare);
swap_if(first[9u], first[13u], compare);
swap_if(first[11u], first[15u], compare);
swap_if(first[11u], first[13u], compare);
swap_if(first[11u], first[12u], compare);

Есть, вероятно, более разумные способы создания сетей медианного поиска, но я не думаю, что обширные исследования были проведены по этому вопросу. Поэтому, вероятно, это лучший метод, который вы можете использовать на данный момент. Результат невелик, но он по-прежнему использует 104 единицы обмена обмена вместо 118.

Ответ 2

Общая идея

Рассматривая исходный код std::nth_element в MSVC2013, кажется, что случаи N <= 32 решаются путем сортировки вставки. Это означает, что разработчики STL поняли, что выполнение рандомизированных разделов будет медленнее, несмотря на лучшую асимптотику для этих размеров.

Одним из способов повышения производительности является оптимизация алгоритма сортировки. @Morwenn answer показывает, как сортировать 23 элемента с сортировочной сетью, которая, как известно, является одним из самых быстрых способов сортировки небольших массивов с постоянным размером. Я буду исследовать другой способ, который заключается в вычислении медианы без алгоритма сортировки. Фактически, я не буду переставлять массив ввода вообще.

Поскольку мы говорим о небольших массивах, нам нужно реализовать некоторый алгоритм O (N ^ 2) самым простым способом. В идеале, он не должен иметь никаких ветвей вообще или только хорошо предсказуемых ветвей. Кроме того, простая структура алгоритма может позволить нам векторизовать его, улучшая его производительность.

Алгоритм

Я решил следовать методу подсчета, который использовался здесь, чтобы ускорить небольшой линейный поиск. Прежде всего, предположим, что все элементы разные. Выберите любой элемент массива: количество элементов меньше, чем оно определяет свою позицию в отсортированном массиве. Мы можем выполнять итерацию по всем элементам, и для каждого из них вычислять количество элементов меньше, чем оно. Если отсортированный индекс имеет желаемое значение, мы можем остановить алгоритм.

К сожалению, в общем случае могут быть равные элементы. Нам придется сделать наш алгоритм значительно медленнее и сложнее справиться с ними. Вместо вычисления уникального отсортированного индекса элемента мы можем рассчитать интервал возможных отсортированных индексов для него. Для любого элемента достаточно подсчитать количество элементов меньше, чем оно (L), и количество элементов, равных ему (E), а затем отсортированный индекс соответствует диапазону [L, L + R). Если этот интервал содержит требуемый отсортированный индекс (т.е. N/2), то мы можем остановить алгоритм и вернуть рассмотренный элемент.

for (size_t i = 0; i < n; i++) {
    auto x = arr[i];
    //count number of "less" and "equal" elements
    int cntLess = 0, cntEq = 0;
    for (size_t j = 0; j < n; j++) {
        cntLess += arr[j] < x;
        cntEq += arr[j] == x;
    }
    //fast range checking from here: https://stackoverflow.com/a/17095534/556899
    if ((unsigned int)(idx - cntLess) < cntEq)
        return x;
}

Векторизация

Построенный алгоритм имеет только одну ветвь, которая является довольно предсказуемой: она терпит неудачу во всех случаях, за исключением единственного случая, когда мы останавливаем алгоритм. Алгоритм легко векторизовать, используя 8 элементов на регистр SSE. Поскольку нам нужно будет получить доступ к некоторым элементам после последнего, я буду считать, что входной массив дополняется значениями max = 2 ^ 15-1 до 24 или 32 элементов.

Первый способ заключается в векторизации внутреннего цикла на j. В этом случае внутренний цикл будет выполняться только 3 раза, но после его завершения необходимо выполнить два сокращения по ширине. Они едят больше времени, чем внутренняя петля. В результате такая векторизация не очень эффективна.

Второй способ заключается в векторизации внешнего цикла на i. В этом случае мы обрабатываем 8 элементов x = arr[i] сразу. Для каждого пакета мы сравниваем его с каждым элементом arr[j] во внутреннем цикле. После внутреннего цикла мы выполняем векторизованную проверку диапазона для всего пакета из 8 элементов. Если какой-либо из них преуспевает, мы определяем точное число с простым скалярным кодом (он все равно мало ходит).

__m128i idxV = _mm_set1_epi16(idx);
for (size_t i = 0; i < n; i += 8) {
    //load pack of 8 elements
    auto xx = _mm_loadu_si128((__m128i*)&arr[i]);
    //count number of less/equal elements for each element in the pack
    __m128i cntLess = _mm_setzero_si128();
    __m128i cntEq = _mm_setzero_si128();
    for (size_t j = 0; j < n; j++) {
        __m128i vAll = _mm_set1_epi16(arr[j]);
        cntLess = _mm_sub_epi16(cntLess, _mm_cmplt_epi16(vAll, xx));
        cntEq = _mm_sub_epi16(cntEq, _mm_cmpeq_epi16(vAll, xx));
    }
    //perform range check for 8 elements at once
    __m128i mask = _mm_andnot_si128(_mm_cmplt_epi16(idxV, cntLess), _mm_cmplt_epi16(idxV, _mm_add_epi16(cntLess, cntEq)));
    if (int bm = _mm_movemask_epi8(mask)) {
        //range check succeeds for one of the elements, find and return it 
        for (int t = 0; t < 8; t++)
            if (bm & (1 << (2*t)))
                return arr[i + t];
    }
}

Здесь мы видим _mm_set1_epi16 внутреннюю в самом внутреннем цикле. У GCC, похоже, есть некоторые проблемы с производительностью. Во всяком случае, это время есть на каждой внутренней итерации, которая может быть уменьшена, если мы обрабатываем 8 элементов одновременно в самой внутренней петле. В этом случае мы можем сделать одну векторизованную нагрузку и 14 команд распаковать для получения vAll для восьми элементов. Кроме того, нам нужно будет написать код сравнения и подсчета для восьми элементов в теле цикла, поэтому он также работает как 8-кратный разворот. Полученный код является самым быстрым, ссылка на него приведена ниже.

Сравнение

Я сравнивал различные решения на процессоре Ivy Bridge 3.4 Ghz. Ниже вы можете увидеть общее время вычисления для 2 ^ 23 ~ = 8M вызовов в секундах (первое число). Второе число - контрольная сумма результатов.

Результаты MSVC 2013 x64 (/O2):

memcpy only: 0.020
std::nth_element: 2.110 (1186136064)
network sort: 0.630 (1186136064)              //solution by @Morwenn (I had to change swap_if)
trivial count: 2.266 (1186136064)             //scalar algorithm (presented above)
vectorized count: 0.692 (1186136064)          //vectorization by j
vectorized count (T): 0.602 (1186136064)      //vectorization by i (presented above)
vectorized count (both): 0.450 (1186136064)   //vectorization by i and j

Результаты на MinGW GCC 4.8.3 x64 (-O3 -msse4):

memcpy only: 0.016
std::nth_element: 1.981 (1095237632)
network sort: 0.531 (1095237632)              //original swap_if used
trivial count: 1.482 (1095237632)
vectorized count: 0.655 (1095237632)
vectorized count (T): 2.668 (1095237632)      //GCC generates some crap
vectorized count (both): 0.374 (1095237632)

Как видите, предложенный алгоритм для 23 16-разрядных элементов немного быстрее, чем метод на основе сортировки (BTW, на более раннем процессоре я вижу только 5% разницу во времени). Если вы можете гарантировать, что все элементы разные, вы можете упростить алгоритм, сделав его еще быстрее.

Полный код всех алгоритмов доступен здесь, включая весь тестовый код.

Ответ 3

Я нашел эту проблему интересной, поэтому я попробовал все алгоритмы, о которых я мог думать.
Вот результаты:

testing 100000 repetitions
variant 0, no-op (for overhead measure)
5 ms
variant 1, vector + nth_element
205 ms
variant 2, multiset + advance
745 ms
variant 2b, set (not fully conformant)
787 ms
variant 3, list + lower_bound
589 ms
variant 3b, list + block-allocator
269 ms
variant 4, avl-tree + insert_sorted
645 ms
variant 4b, avl-tree + prune
682 ms
variant 5, histogram
1429 ms

Я думаю, мы можем заключить, что вы, где уже используете самые быстрые алгоритм. Мальчик был я неправ. Однако, если вы можете принять приблизительный ответ, вероятно, есть более быстрые способы, такие как медиана медианов.
Если вам интересно, источник здесь.