Как оптимизировать этот код Haskell, суммируя простые числа в сублинейном времени?

Задача 10 из Project Euler заключается в том, чтобы найти сумму всех простых чисел, указанных ниже.

Я решил это просто путем суммирования простых чисел, генерируемых ситом Эратосфена. Затем я встретил гораздо более эффективное решение Lucy_Hedgehog (сублинейное!).

При n = 2⋅10 ^ 9:

  • Код Python (из приведенной выше цитаты) выполняется через 1.2 секунды в Python 2.7.3.

  • Код С++ (мой) работает примерно через 0,3 секунды (скомпилирован с g++ 4.8.4).

Я повторил тот же алгоритм в Haskell, так как я его изучаю:

import Data.List

import Data.Map (Map, (!))
import qualified Data.Map as Map

problem10 :: Integer -> Integer
problem10 n = (sieve (Map.fromList [(i, i * (i + 1) `div` 2 - 1) | i <- vs]) 2 r vs) ! n
              where vs = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]
                    r  = floor (sqrt (fromIntegral n))

sieve :: Map Integer Integer -> Integer -> Integer -> [Integer] -> Map Integer Integer
sieve m p r vs | p > r     = m
               | otherwise = sieve (if m ! p > m ! (p - 1) then update m vs p else m) (p + 1) r vs

update :: Map Integer Integer -> [Integer] -> Integer -> Map Integer Integer
update m vs p = foldl' decrease m (map (\v -> (v, sumOfSieved m v p)) (takeWhile (>= p*p) vs))

decrease :: Map Integer Integer -> (Integer, Integer) -> Map Integer Integer
decrease m (k, v) = Map.insertWith (flip (-)) k v m

sumOfSieved :: Map Integer Integer -> Integer -> Integer -> Integer
sumOfSieved m v p = p * (m ! (v `div` p) - m ! (p - 1))

main = print $ problem10 $ 2*10^9

Я скомпилировал его с помощью ghc -O2 10.hs и запустил с помощью time ./10.

Он дает правильный ответ, но занимает около 7 секунд.

Я скомпилировал его с помощью ghc -prof -fprof-auto -rtsopts 10 и запустил с помощью ./10 +RTS -p -h.

10.prof показывает, что decrease занимает 52,2% времени и 67,5% отчислений.

После запуска hp2ps 10.hp у меня появился такой профиль кучи:

hp

Снова выглядит как decrease занимает большую часть кучи. GHC версия 7.6.3.

Как бы вы оптимизировали время выполнения этого кода Haskell?


Обновление 13.06.17:

I попробовал заменить неизменяемый Data.Map на mutable Data.HashTable.IO.BasicHashTable из пакета hashtables, но я, вероятно, делаю что-то плохое, так как для крошечного n = 30 он занимает слишком много времени, около 10 секунд. Что не так?

Обновление 18.06.17:

Интересно отметить проблемы производительности HashTable. Я взял Sherh код, используя измененный Data.HashTable.ST.Linear, но отброшен Data.Judy вместо. Он работает через 1,1 секунды, все еще относительно медленно.

Ответы

Ответ 1

Я сделал небольшие улучшения, поэтому он запускается в 3.4-3.5 секунд на моей машине. Использование IntMap.Strict помогло. Кроме этого, я просто выполнил несколько оптимизаций ghc, чтобы быть уверенным. И сделайте код Haskell ближе к коду Python из вашей ссылки. В качестве следующего шага вы можете попытаться использовать некоторые mutable HashMap. Но я не уверен... IntMap не может быть намного быстрее, чем какой-либо изменчивый контейнер, потому что он неизменный. Хотя я все еще удивлен этой эффективностью. Я надеюсь, что это может быть реализовано быстрее.

Вот код:

import Data.List (foldl')
import Data.IntMap.Strict (IntMap, (!))
import qualified Data.IntMap.Strict as IntMap

p :: Int -> Int
p n = (sieve (IntMap.fromList [(i, i * (i + 1) `div` 2 - 1) | i <- vs]) 2 r vs) ! n
               where vs = [n `div` i | i <- [1..r]] ++ [n', n' - 1 .. 1]
                     r  = floor (sqrt (fromIntegral n) :: Double)
                     n' = n `div` r - 1

sieve :: IntMap Int -> Int -> Int -> [Int] -> IntMap Int
sieve m' p' r vs = go m' p'
  where
    go m p | p > r               = m
           | m ! p > m ! (p - 1) = go (update m vs p) (p + 1)
           | otherwise           = go m (p + 1)

update :: IntMap Int -> [Int] -> Int -> IntMap Int
update s vs p = foldl' decrease s (takeWhile (>= p2) vs)
  where
    sp = s ! (p - 1)
    p2 = p * p
    sumOfSieved v = p * (s ! (v `div` p) - sp)
    decrease m  v = IntMap.adjust (subtract $ sumOfSieved v) v m

main :: IO ()
main = print $ p $ 2*10^(9 :: Int) 

UPDATE:

Использование изменчивого hashtables Мне удалось сделать производительность до ~5.5sec на Haskell с этой реализацией.

Кроме того, я использовал ненужные векторы вместо списков в нескольких местах. Linear хеширование кажется самым быстрым. Я думаю, что это можно сделать еще быстрее. Я заметил sse42 вариант в пакете hasthables. Не уверен, что мне удалось установить его правильно, но даже без него это происходит быстро.

ОБНОВЛЕНИЕ 2 (19.06.2017)

Мне удалось сделать это 3x быстрее, чем лучшее решение от @Krom (используя мой код + его карту), вообще-то сбросив суффикс judy. Вместо этого используются только простые массивы. Вы можете придумать ту же идею, если заметите, что ключи для hashmap являются либо последовательностью от 1 до n', либо n div i для i от 1 до r. Таким образом, мы можем представить такой HashMap как два массива, делающих поиск в массиве в зависимости от ключа поиска.

Мой код + Judy HashMap

$ time ./judy
95673602693282040

real    0m0.590s
user    0m0.588s
sys     0m0.000s

Мой код + моя разреженная карта

$ time ./sparse
95673602693282040

real    0m0.203s
user    0m0.196s
sys     0m0.004s

Это можно сделать еще быстрее, если вместо IOUArray уже созданы векторы и библиотека Vector, а readArray заменяется на unsafeRead. Но я не думаю, что это должно быть сделано, если только вам не очень интересно оптимизировать это как можно больше.

Сравнение с этим решением - обман и нечестно. Я ожидаю, что те же идеи, реализованные в Python и С++, будут еще быстрее. Но решение @Krom с закрытым хэшмапом уже обманывает, потому что оно использует настраиваемую структуру данных, а не стандартную. По крайней мере, вы можете видеть, что стандартные и самые популярные хэш-карты в Haskell не так быстро. Использование более эффективных алгоритмов и улучшенных структур ad-hoc может быть лучше для таких проблем.

Здесь в результате код.

Ответ 2

Сначала как базовый уровень, сроки существующих подходов на моей машине:

  • Оригинальная программа, размещенная в вопросе:

    time stack exec primorig
    95673602693282040
    
    real    0m4.601s
    user    0m4.387s
    sys     0m0.251s
    
  • Вторая версия с использованием Data.IntMap.Strict от здесь

    time stack exec primIntMapStrict
    95673602693282040
    
    real    0m2.775s
    user    0m2.753s
    sys     0m0.052s
    
  • Код Shershs с Data.Judy упал в здесь

    time stack exec prim-hash2
    95673602693282040
    
    real    0m0.945s
    user    0m0.955s
    sys     0m0.028s
    
  • Ваше решение python.

    Я скомпилировал его с помощью

    python -O -m py_compile problem10.py
    

    и время:

    time python __pycache__/problem10.cpython-36.opt-1.pyc
    95673602693282040
    
    real    0m1.163s
    user    0m1.160s
    sys     0m0.003s
    
  • Ваша версия на С++:

    $ g++ -O2 --std=c++11 p10.cpp -o p10
    $ time ./p10
    sum(2000000000) = 95673602693282040
    
    real    0m0.314s
    user    0m0.310s
    sys     0m0.003s
    

Я не потрудился предоставить базовый уровень для slow.hs, поскольку я не хотите дождаться, когда он будет завершен, когда будет запущен аргумент 2*10^9.

Достаточная производительность

Следующая программа работает под вторым на моей машине.

Использует ручной хэш файл, который использует замкнутое хеширование с помощью линейного зондирования и использует некоторый вариант хэш-функции knuths, см. здесь.

Конечно, это несколько скроено для случая, так как поиск например, ожидает, что найденные ключи будут присутствовать.

Тайминги:

time stack exec prim
95673602693282040

real    0m0.725s
user    0m0.714s
sys     0m0.047s

Сначала я реализовал ручной хэш файл, просто хэш клавиши с

key `mod` size

и выберите размер, который в несколько раз превышает ожидаемый вход, но программа заняла 22 или более баллов.

Наконец, речь шла о выборе хэш-функции, которая была хорошо для рабочей нагрузки.

Вот программа:

import Data.Maybe
import Control.Monad
import Data.Array.IO
import Data.Array.Base (unsafeRead)

type Number = Int

data Map = Map { keys :: IOUArray Int Number
               , values :: IOUArray Int Number
               , size :: !Int 
               , factor :: !Int
               }

newMap :: Int -> Int -> IO Map
newMap s f = do
  k <- newArray (0, s-1) 0
  v <- newArray (0, s-1) 0
  return $ Map k v s f 

storeKey :: IOUArray Int Number -> Int -> Int -> Number -> IO Int
storeKey arr s f key = go ((key * f) `mod` s)
  where
    go :: Int -> IO Int
    go ind = do
      v <- readArray arr ind
      go2 v ind
    go2 v ind
      | v == 0    = do { writeArray arr ind key; return ind; }
      | v == key  = return ind
      | otherwise = go ((ind + 1) `mod` s)

loadKey :: IOUArray Int Number -> Int -> Int -> Number -> IO Int
loadKey arr s f key = s `seq` key `seq` go ((key *f) `mod` s)
  where
    go :: Int -> IO Int
    go ix = do
      v <- unsafeRead arr ix
      if v == key then return ix else go ((ix + 1) `mod` s)

insertIntoMap :: Map -> (Number, Number) -> IO Map
insertIntoMap [email protected](Map ks vs s f) (k, v) = do
  ix <- storeKey ks s f k
  writeArray vs ix v
  return m

fromList :: Int -> Int -> [(Number, Number)] -> IO Map
fromList s f xs = do
  m <- newMap s f
  foldM insertIntoMap m xs

(!) :: Map -> Number -> IO Number
(!) (Map ks vs s f) k = do
  ix <- loadKey ks s f k
  readArray vs ix

mupdate :: Map -> Number -> (Number -> Number) -> IO ()
mupdate (Map ks vs s fac) i f = do
  ix <- loadKey ks s fac i
  old <- readArray vs ix
  let x' = f old
  x' `seq` writeArray vs ix x'

r' :: Number -> Number
r'  = floor . sqrt . fromIntegral

vs' :: Integral a => a -> a -> [a]
vs' n r = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]  

vss' n r = r + n `div` r -1

list' :: Int -> Int -> [Number] -> IO Map
list' s f vs = fromList s f [(i, i * (i + 1) `div` 2 - 1) | i <- vs]

problem10 :: Number -> IO Number
problem10 n = do
      m <- list' (19*vss) (19*vss+7) vs
      nm <- sieve m 2 r vs
      nm ! n
    where vs = vs' n r
          vss = vss' n r
          r  = r' n

sieve :: Map -> Number -> Number -> [Number] -> IO Map
sieve m p r vs | p > r     = return m
               | otherwise = do
                   v1 <- m ! p
                   v2 <- m ! (p - 1)
                   nm <- if v1 > v2 then update m vs p else return m
                   sieve nm (p + 1) r vs

update :: Map -> [Number] -> Number -> IO Map
update m vs p = foldM (decrease p) m $ takeWhile (>= p*p) vs

decrease :: Number -> Map -> Number -> IO Map
decrease p m k = do
  v <- sumOfSieved m k p
  mupdate m k (subtract v)
  return m

sumOfSieved :: Map -> Number -> Number -> IO Number
sumOfSieved m v p = do
  v1 <- m ! (v `div` p)
  v2 <- m ! (p - 1)
  return $ p * (v1 - v2)

main = do { n <- problem10 (2*10^9) ; print n; } -- 2*10^9

Я не профессионал с хэшированием и такими вещами, поэтому это, безусловно, может быть улучшено. Может быть, мы, Хаскеллеры, должны улучшить хеш-карты полки или предоставить несколько более простых.

My hashmap, код Shershs

Если я подключу свой хэш файл в коде Shershs (см. ниже), см. здесь мы даже до

time stack exec prim-hash2
95673602693282040

real    0m0.601s
user    0m0.604s
sys     0m0.034s

Почему slow.hs медленно?

Если вы читаете источник для функции insert в Data.HashTable.ST.Basic, вы увидит, что он удаляет старую пару значений ключа и вставляет новенький. Он не ищет "место" для ценности и мутировать его, как можно себе представить, если прочитать, что это "изменчивая" хэш-таблица. Здесь хэш-таблица является изменчивой, поэтому вам не нужно копировать всю хэш-таблицу для вставки новой пары значений ключа, но значения для пар не. Я не знаю, если это вся история slow.hs будучи медленным, но я предполагаю, что это довольно большая его часть.

Несколько незначительных улучшений

Итак, идея, которую я выполнял, пытаясь улучшить вашей программы в первый раз.

См., вам не требуется измененное сопоставление от ключей к значениям. Ваш набор ключей исправлен. Вы хотите, чтобы сопоставление ключей с изменяемыми мест. (Это, кстати, то, что вы получаете от С++ по умолчанию.)

И поэтому я попытался придумать это. Я использовал IntMap IORef из Data.IntMap.Strict и Data.IORef сначала и получили время из

tack exec prim
95673602693282040

real    0m2.134s
user    0m2.141s
sys     0m0.028s

Я подумал, может быть, это поможет работать с unboxed values и для этого я использовал IOUArray Int Int с 1 элементом каждый вместо IORef и получил эти тайминги:

time stack exec prim
95673602693282040

real    0m2.015s
user    0m2.018s
sys     0m0.038s

Не большая разница, и поэтому я попытался избавиться от границ проверка в 1 массиве элементов с помощью unsafeRead и unsafeWrite и получил время

time stack exec prim
95673602693282040

real    0m1.845s
user    0m1.850s
sys     0m0.030s

который был лучшим, я использовал Data.IntMap.Strict.

Конечно, я запускал каждую программу несколько раз, чтобы узнать, времена стабильны и различия во времени выполнения не просто шум.

Похоже, что это всего лишь микро-оптимизация.

И вот программа, которая работает быстрее для меня, не используя ручную структуру данных:

import qualified Data.IntMap.Strict as M
import Control.Monad
import Data.Array.IO
import Data.Array.Base (unsafeRead, unsafeWrite)

type Number = Int
type Place = IOUArray Number Number
type Map = M.IntMap Place

tupleToRef :: (Number, Number) -> IO (Number, Place)
tupleToRef = traverse (newArray (0,0))

insertRefs :: [(Number, Number)] -> IO [(Number, Place)]
insertRefs = traverse tupleToRef

fromList :: [(Number, Number)] -> IO Map 
fromList xs = M.fromList <$> insertRefs xs

(!) :: Map -> Number -> IO Number
(!) m i = unsafeRead (m M.! i) 0

mupdate :: Map -> Number -> (Number -> Number) -> IO ()
mupdate m i f = do
  let place = m M.! i
  old <- unsafeRead place 0
  let x' = f old
  -- make the application of f strict
  x' `seq` unsafeWrite place 0 x'

r' :: Number -> Number
r'  = floor . sqrt . fromIntegral

vs' :: Integral a => a -> a -> [a]
vs' n r = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]  

list' :: [Number] -> IO Map
list' vs = fromList [(i, i * (i + 1) `div` 2 - 1) | i <- vs]

problem10 :: Number -> IO Number
problem10 n = do
      m <- list' vs
      nm <- sieve m 2 r vs
      nm ! n
    where vs = vs' n r
          r  = r' n

sieve :: Map -> Number -> Number -> [Number] -> IO Map
sieve m p r vs | p > r     = return m
               | otherwise = do
                   v1 <- m ! p
                   v2 <- m ! (p - 1)
                   nm <- if v1 > v2 then update m vs p else return m
                   sieve nm (p + 1) r vs

update :: Map -> [Number] -> Number -> IO Map
update m vs p = foldM (decrease p) m $ takeWhile (>= p*p) vs

decrease :: Number -> Map -> Number -> IO Map
decrease p m k = do
  v <- sumOfSieved m k p
  mupdate m k (subtract v)
  return m

sumOfSieved :: Map -> Number -> Number -> IO Number
sumOfSieved m v p = do
  v1 <- m ! (v `div` p)
  v2 <- m ! (p - 1)
  return $ p * (v1 - v2)

main = do { n <- problem10 (2*10^9) ; print n; } -- 2*10^9

Если вы прокомментируете это, вы увидите, что большую часть времени он проводит в пользовательской функции поиска (!), не знаю, как улучшить это дальше. Попытка встраивать (!) с помощью {-# INLINE (!) #-} не дали лучших результатов; возможно, ghc уже сделал это.

Ответ 3

Этот мой код оценивает сумму до 2⋅10 ^ 9 за 0,3 секунды и сумму до 10 ^ 12 (18435588552550705911377) за 19,6 секунды (при наличии достаточного количества ОЗУ).

import Control.DeepSeq 
import qualified Control.Monad as ControlMonad
import qualified Data.Array as Array
import qualified Data.Array.ST as ArrayST
import qualified Data.Array.Base as ArrayBase

primeLucy :: (Integer -> Integer) -> (Integer -> Integer) -> Integer -> (Integer->Integer)
primeLucy f sf n = g
  where
    r = fromIntegral $ integerSquareRoot n
    ni = fromIntegral n
    loop from to c = let go i = ControlMonad.when (to<=i) (c i >> go (i-1)) in go from

    k = ArrayST.runSTArray $ do
      k <- ArrayST.newListArray (-r,r) $ force $
        [sf (div n (toInteger i)) - sf 1|i<-[r,r-1..1]] ++
        [0] ++
        [sf (toInteger i) - sf 1|i<-[1..r]]
      ControlMonad.forM_ (takeWhile (<=r) primes) $ \p -> do
        l <- ArrayST.readArray k (p-1)
        let q = force $ f (toInteger p)

        let adjust = \i j -> do { v <- ArrayBase.unsafeRead k (i+r); w <- ArrayBase.unsafeRead k (j+r); ArrayBase.unsafeWrite k (i+r) $!! v+q*(l-w) }

        loop (-1)         (-div r p)              $ \i -> adjust i (i*p)
        loop (-div r p-1) (-min r (div ni (p*p))) $ \i -> adjust i (div (-ni) (i*p))
        loop r            (p*p)                   $ \i -> adjust i (div i p)

      return k

    g :: Integer -> Integer
    g m
      | m >= 1 && m <= integerSquareRoot n                       = k Array.! (fromIntegral m)
      | m >= integerSquareRoot n && m <= n && div n (div n m)==m = k Array.! (fromIntegral (negate (div n m)))
      | otherwise = error $ "Function not precalculated for value " ++ show m

primeSum :: Integer -> Integer
primeSum n = (primeLucy id (\m -> div (m*m+m) 2) n) n

Если ваша функция integerSquareRoot глючит (как сообщается, некоторые из них), вы можете заменить ее здесь floor . sqrt . fromIntegral.

Пояснение:

Как следует из названия, оно основано на обобщении известного метода "Люси-Ежик", который в конечном итоге был обнаружен оригинальным плакатом.

Он позволяет рассчитать много сумм формы sum (с p prime) без перечисления всех простых чисел до N и во времени O (N ^ 0,75).

Его входы - это функция f (т.е. id, если вы хотите получить первую сумму), ее суммирующая функция по всем целым числам (т.е. в этом случае сумма первых m целых чисел или div (m*m+m) 2) и N.

PrimeLucy возвращает функцию поиска eq (с p prime) ограничено определенными значениями n: values ​​.

Ответ 4

Попробуйте это и сообщите мне, насколько это быстро:

-- sum of primes

import Control.Monad (forM_, when)
import Control.Monad.ST
import Data.Array.ST
import Data.Array.Unboxed

sieve :: Int -> UArray Int Bool
sieve n = runSTUArray $ do
    let m = (n-1) `div` 2
        r = floor . sqrt $ fromIntegral n
    bits <- newArray (0, m-1) True
    forM_ [0 .. r `div` 2 - 1] $ \i -> do
        isPrime <- readArray bits i
        when isPrime $ do
            let a = 2*i*i + 6*i + 3
                b = 2*i*i + 8*i + 6
            forM_ [a, b .. (m-1)] $ \j -> do
                writeArray bits j False
    return bits

primes :: Int -> [Int]
primes n = 2 : [2*i+3 | (i, True) <- assocs $ sieve n]

main = do
    print $ sum $ primes 1000000

Вы можете запустить его на ideone. Моим алгоритмом является Сито Эратосфена, и оно должно быть довольно быстрым для малых n. При n = 2 000 000 000 размер массива может быть проблемой, и в этом случае вам нужно будет использовать сегментированное сито. См. мой блог для получения дополнительной информации о сите Эратосфена. См. этот ответ для информации о сегментированном сите (но не в Haskell, к сожалению).