Как оптимизировать этот код Haskell, суммируя простые числа в сублинейном времени?
Задача 10 из Project Euler заключается в том, чтобы найти сумму всех простых чисел, указанных ниже.
Я решил это просто путем суммирования простых чисел, генерируемых ситом Эратосфена. Затем я встретил гораздо более эффективное решение Lucy_Hedgehog (сублинейное!).
При n = 2⋅10 ^ 9:
-
Код Python (из приведенной выше цитаты) выполняется через 1.2 секунды в Python 2.7.3.
-
Код С++ (мой) работает примерно через 0,3 секунды (скомпилирован с g++ 4.8.4).
Я повторил тот же алгоритм в Haskell, так как я его изучаю:
import Data.List
import Data.Map (Map, (!))
import qualified Data.Map as Map
problem10 :: Integer -> Integer
problem10 n = (sieve (Map.fromList [(i, i * (i + 1) `div` 2 - 1) | i <- vs]) 2 r vs) ! n
where vs = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]
r = floor (sqrt (fromIntegral n))
sieve :: Map Integer Integer -> Integer -> Integer -> [Integer] -> Map Integer Integer
sieve m p r vs | p > r = m
| otherwise = sieve (if m ! p > m ! (p - 1) then update m vs p else m) (p + 1) r vs
update :: Map Integer Integer -> [Integer] -> Integer -> Map Integer Integer
update m vs p = foldl' decrease m (map (\v -> (v, sumOfSieved m v p)) (takeWhile (>= p*p) vs))
decrease :: Map Integer Integer -> (Integer, Integer) -> Map Integer Integer
decrease m (k, v) = Map.insertWith (flip (-)) k v m
sumOfSieved :: Map Integer Integer -> Integer -> Integer -> Integer
sumOfSieved m v p = p * (m ! (v `div` p) - m ! (p - 1))
main = print $ problem10 $ 2*10^9
Я скомпилировал его с помощью ghc -O2 10.hs
и запустил с помощью time ./10
.
Он дает правильный ответ, но занимает около 7 секунд.
Я скомпилировал его с помощью ghc -prof -fprof-auto -rtsopts 10
и запустил с помощью ./10 +RTS -p -h
.
10.prof показывает, что decrease
занимает 52,2% времени и 67,5% отчислений.
После запуска hp2ps 10.hp
у меня появился такой профиль кучи:
![hp]()
Снова выглядит как decrease
занимает большую часть кучи. GHC версия 7.6.3.
Как бы вы оптимизировали время выполнения этого кода Haskell?
Обновление 13.06.17:
I попробовал заменить неизменяемый Data.Map
на mutable Data.HashTable.IO.BasicHashTable
из пакета hashtables
, но я, вероятно, делаю что-то плохое, так как для крошечного n = 30 он занимает слишком много времени, около 10 секунд. Что не так?
Обновление 18.06.17:
Интересно отметить проблемы производительности HashTable. Я взял Sherh код, используя измененный Data.HashTable.ST.Linear
, но отброшен Data.Judy
вместо. Он работает через 1,1 секунды, все еще относительно медленно.
Ответы
Ответ 1
Я сделал небольшие улучшения, поэтому он запускается в 3.4-3.5
секунд на моей машине.
Использование IntMap.Strict
помогло. Кроме этого, я просто выполнил несколько оптимизаций ghc
, чтобы быть уверенным. И сделайте код Haskell ближе к коду Python из вашей ссылки. В качестве следующего шага вы можете попытаться использовать некоторые mutable HashMap
. Но я не уверен... IntMap
не может быть намного быстрее, чем какой-либо изменчивый контейнер, потому что он неизменный. Хотя я все еще удивлен этой эффективностью. Я надеюсь, что это может быть реализовано быстрее.
Вот код:
import Data.List (foldl')
import Data.IntMap.Strict (IntMap, (!))
import qualified Data.IntMap.Strict as IntMap
p :: Int -> Int
p n = (sieve (IntMap.fromList [(i, i * (i + 1) `div` 2 - 1) | i <- vs]) 2 r vs) ! n
where vs = [n `div` i | i <- [1..r]] ++ [n', n' - 1 .. 1]
r = floor (sqrt (fromIntegral n) :: Double)
n' = n `div` r - 1
sieve :: IntMap Int -> Int -> Int -> [Int] -> IntMap Int
sieve m' p' r vs = go m' p'
where
go m p | p > r = m
| m ! p > m ! (p - 1) = go (update m vs p) (p + 1)
| otherwise = go m (p + 1)
update :: IntMap Int -> [Int] -> Int -> IntMap Int
update s vs p = foldl' decrease s (takeWhile (>= p2) vs)
where
sp = s ! (p - 1)
p2 = p * p
sumOfSieved v = p * (s ! (v `div` p) - sp)
decrease m v = IntMap.adjust (subtract $ sumOfSieved v) v m
main :: IO ()
main = print $ p $ 2*10^(9 :: Int)
UPDATE:
Использование изменчивого hashtables
Мне удалось сделать производительность до ~5.5sec
на Haskell с этой реализацией.
Кроме того, я использовал ненужные векторы вместо списков в нескольких местах. Linear
хеширование кажется самым быстрым. Я думаю, что это можно сделать еще быстрее. Я заметил sse42
вариант в пакете hasthables
. Не уверен, что мне удалось установить его правильно, но даже без него это происходит быстро.
ОБНОВЛЕНИЕ 2 (19.06.2017)
Мне удалось сделать это 3x
быстрее, чем лучшее решение от @Krom (используя мой код + его карту), вообще-то сбросив суффикс judy. Вместо этого используются только простые массивы. Вы можете придумать ту же идею, если заметите, что ключи для hashmap являются либо последовательностью от 1
до n'
, либо n div i
для i
от 1
до r
. Таким образом, мы можем представить такой HashMap как два массива, делающих поиск в массиве в зависимости от ключа поиска.
Мой код + Judy HashMap
$ time ./judy
95673602693282040
real 0m0.590s
user 0m0.588s
sys 0m0.000s
Мой код + моя разреженная карта
$ time ./sparse
95673602693282040
real 0m0.203s
user 0m0.196s
sys 0m0.004s
Это можно сделать еще быстрее, если вместо IOUArray
уже созданы векторы и библиотека Vector
, а readArray
заменяется на unsafeRead
. Но я не думаю, что это должно быть сделано, если только вам не очень интересно оптимизировать это как можно больше.
Сравнение с этим решением - обман и нечестно. Я ожидаю, что те же идеи, реализованные в Python и С++, будут еще быстрее. Но решение @Krom с закрытым хэшмапом уже обманывает, потому что оно использует настраиваемую структуру данных, а не стандартную. По крайней мере, вы можете видеть, что стандартные и самые популярные хэш-карты в Haskell не так быстро. Использование более эффективных алгоритмов и улучшенных структур ad-hoc может быть лучше для таких проблем.
Здесь в результате код.
Ответ 2
Сначала как базовый уровень, сроки существующих подходов
на моей машине:
-
Оригинальная программа, размещенная в вопросе:
time stack exec primorig
95673602693282040
real 0m4.601s
user 0m4.387s
sys 0m0.251s
-
Вторая версия с использованием Data.IntMap.Strict
от
здесь
time stack exec primIntMapStrict
95673602693282040
real 0m2.775s
user 0m2.753s
sys 0m0.052s
-
Код Shershs с Data.Judy упал в здесь
time stack exec prim-hash2
95673602693282040
real 0m0.945s
user 0m0.955s
sys 0m0.028s
-
Ваше решение python.
Я скомпилировал его с помощью
python -O -m py_compile problem10.py
и время:
time python __pycache__/problem10.cpython-36.opt-1.pyc
95673602693282040
real 0m1.163s
user 0m1.160s
sys 0m0.003s
-
Ваша версия на С++:
$ g++ -O2 --std=c++11 p10.cpp -o p10
$ time ./p10
sum(2000000000) = 95673602693282040
real 0m0.314s
user 0m0.310s
sys 0m0.003s
Я не потрудился предоставить базовый уровень для slow.hs, поскольку я не
хотите дождаться, когда он будет завершен, когда будет запущен аргумент
2*10^9
.
Достаточная производительность
Следующая программа работает под вторым на моей машине.
Использует ручной хэш файл, который использует замкнутое хеширование с помощью
линейного зондирования и использует некоторый вариант хэш-функции knuths,
см. здесь.
Конечно, это несколько скроено для случая, так как поиск
например, ожидает, что найденные ключи будут присутствовать.
Тайминги:
time stack exec prim
95673602693282040
real 0m0.725s
user 0m0.714s
sys 0m0.047s
Сначала я реализовал ручной хэш файл, просто хэш
клавиши с
key `mod` size
и выберите размер, который в несколько раз превышает ожидаемый
вход, но программа заняла 22 или более баллов.
Наконец, речь шла о выборе хэш-функции, которая была
хорошо для рабочей нагрузки.
Вот программа:
import Data.Maybe
import Control.Monad
import Data.Array.IO
import Data.Array.Base (unsafeRead)
type Number = Int
data Map = Map { keys :: IOUArray Int Number
, values :: IOUArray Int Number
, size :: !Int
, factor :: !Int
}
newMap :: Int -> Int -> IO Map
newMap s f = do
k <- newArray (0, s-1) 0
v <- newArray (0, s-1) 0
return $ Map k v s f
storeKey :: IOUArray Int Number -> Int -> Int -> Number -> IO Int
storeKey arr s f key = go ((key * f) `mod` s)
where
go :: Int -> IO Int
go ind = do
v <- readArray arr ind
go2 v ind
go2 v ind
| v == 0 = do { writeArray arr ind key; return ind; }
| v == key = return ind
| otherwise = go ((ind + 1) `mod` s)
loadKey :: IOUArray Int Number -> Int -> Int -> Number -> IO Int
loadKey arr s f key = s `seq` key `seq` go ((key *f) `mod` s)
where
go :: Int -> IO Int
go ix = do
v <- unsafeRead arr ix
if v == key then return ix else go ((ix + 1) `mod` s)
insertIntoMap :: Map -> (Number, Number) -> IO Map
insertIntoMap [email protected](Map ks vs s f) (k, v) = do
ix <- storeKey ks s f k
writeArray vs ix v
return m
fromList :: Int -> Int -> [(Number, Number)] -> IO Map
fromList s f xs = do
m <- newMap s f
foldM insertIntoMap m xs
(!) :: Map -> Number -> IO Number
(!) (Map ks vs s f) k = do
ix <- loadKey ks s f k
readArray vs ix
mupdate :: Map -> Number -> (Number -> Number) -> IO ()
mupdate (Map ks vs s fac) i f = do
ix <- loadKey ks s fac i
old <- readArray vs ix
let x' = f old
x' `seq` writeArray vs ix x'
r' :: Number -> Number
r' = floor . sqrt . fromIntegral
vs' :: Integral a => a -> a -> [a]
vs' n r = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]
vss' n r = r + n `div` r -1
list' :: Int -> Int -> [Number] -> IO Map
list' s f vs = fromList s f [(i, i * (i + 1) `div` 2 - 1) | i <- vs]
problem10 :: Number -> IO Number
problem10 n = do
m <- list' (19*vss) (19*vss+7) vs
nm <- sieve m 2 r vs
nm ! n
where vs = vs' n r
vss = vss' n r
r = r' n
sieve :: Map -> Number -> Number -> [Number] -> IO Map
sieve m p r vs | p > r = return m
| otherwise = do
v1 <- m ! p
v2 <- m ! (p - 1)
nm <- if v1 > v2 then update m vs p else return m
sieve nm (p + 1) r vs
update :: Map -> [Number] -> Number -> IO Map
update m vs p = foldM (decrease p) m $ takeWhile (>= p*p) vs
decrease :: Number -> Map -> Number -> IO Map
decrease p m k = do
v <- sumOfSieved m k p
mupdate m k (subtract v)
return m
sumOfSieved :: Map -> Number -> Number -> IO Number
sumOfSieved m v p = do
v1 <- m ! (v `div` p)
v2 <- m ! (p - 1)
return $ p * (v1 - v2)
main = do { n <- problem10 (2*10^9) ; print n; } -- 2*10^9
Я не профессионал с хэшированием и такими вещами, поэтому
это, безусловно, может быть улучшено. Может быть, мы, Хаскеллеры, должны
улучшить хеш-карты полки или предоставить несколько более простых.
My hashmap, код Shershs
Если я подключу свой хэш файл в коде Shershs (см. ниже), см. здесь
мы даже до
time stack exec prim-hash2
95673602693282040
real 0m0.601s
user 0m0.604s
sys 0m0.034s
Почему slow.hs медленно?
Если вы читаете источник
для функции insert
в Data.HashTable.ST.Basic
, вы
увидит, что он удаляет старую пару значений ключа и вставляет
новенький. Он не ищет "место" для ценности и
мутировать его, как можно себе представить, если прочитать, что это
"изменчивая" хэш-таблица. Здесь хэш-таблица является изменчивой,
поэтому вам не нужно копировать всю хэш-таблицу для вставки
новой пары значений ключа, но значения для пар
не. Я не знаю, если это вся история slow.hs
будучи медленным, но я предполагаю, что это довольно большая его часть.
Несколько незначительных улучшений
Итак, идея, которую я выполнял, пытаясь улучшить
вашей программы в первый раз.
См., вам не требуется измененное сопоставление от ключей к значениям.
Ваш набор ключей исправлен. Вы хотите, чтобы сопоставление ключей с изменяемыми
мест. (Это, кстати, то, что вы получаете от С++ по умолчанию.)
И поэтому я попытался придумать это. Я использовал IntMap IORef
из
Data.IntMap.Strict
и Data.IORef
сначала и получили время
из
tack exec prim
95673602693282040
real 0m2.134s
user 0m2.141s
sys 0m0.028s
Я подумал, может быть, это поможет работать с unboxed values
и для этого я использовал IOUArray Int Int
с 1 элементом
каждый вместо IORef
и получил эти тайминги:
time stack exec prim
95673602693282040
real 0m2.015s
user 0m2.018s
sys 0m0.038s
Не большая разница, и поэтому я попытался избавиться от границ
проверка в 1 массиве элементов с помощью unsafeRead
и
unsafeWrite
и получил время
time stack exec prim
95673602693282040
real 0m1.845s
user 0m1.850s
sys 0m0.030s
который был лучшим, я использовал Data.IntMap.Strict
.
Конечно, я запускал каждую программу несколько раз, чтобы узнать,
времена стабильны и различия во времени выполнения не
просто шум.
Похоже, что это всего лишь микро-оптимизация.
И вот программа, которая работает быстрее для меня, не используя ручную структуру данных:
import qualified Data.IntMap.Strict as M
import Control.Monad
import Data.Array.IO
import Data.Array.Base (unsafeRead, unsafeWrite)
type Number = Int
type Place = IOUArray Number Number
type Map = M.IntMap Place
tupleToRef :: (Number, Number) -> IO (Number, Place)
tupleToRef = traverse (newArray (0,0))
insertRefs :: [(Number, Number)] -> IO [(Number, Place)]
insertRefs = traverse tupleToRef
fromList :: [(Number, Number)] -> IO Map
fromList xs = M.fromList <$> insertRefs xs
(!) :: Map -> Number -> IO Number
(!) m i = unsafeRead (m M.! i) 0
mupdate :: Map -> Number -> (Number -> Number) -> IO ()
mupdate m i f = do
let place = m M.! i
old <- unsafeRead place 0
let x' = f old
-- make the application of f strict
x' `seq` unsafeWrite place 0 x'
r' :: Number -> Number
r' = floor . sqrt . fromIntegral
vs' :: Integral a => a -> a -> [a]
vs' n r = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]
list' :: [Number] -> IO Map
list' vs = fromList [(i, i * (i + 1) `div` 2 - 1) | i <- vs]
problem10 :: Number -> IO Number
problem10 n = do
m <- list' vs
nm <- sieve m 2 r vs
nm ! n
where vs = vs' n r
r = r' n
sieve :: Map -> Number -> Number -> [Number] -> IO Map
sieve m p r vs | p > r = return m
| otherwise = do
v1 <- m ! p
v2 <- m ! (p - 1)
nm <- if v1 > v2 then update m vs p else return m
sieve nm (p + 1) r vs
update :: Map -> [Number] -> Number -> IO Map
update m vs p = foldM (decrease p) m $ takeWhile (>= p*p) vs
decrease :: Number -> Map -> Number -> IO Map
decrease p m k = do
v <- sumOfSieved m k p
mupdate m k (subtract v)
return m
sumOfSieved :: Map -> Number -> Number -> IO Number
sumOfSieved m v p = do
v1 <- m ! (v `div` p)
v2 <- m ! (p - 1)
return $ p * (v1 - v2)
main = do { n <- problem10 (2*10^9) ; print n; } -- 2*10^9
Если вы прокомментируете это, вы увидите, что большую часть времени он проводит в пользовательской функции поиска (!)
,
не знаю, как улучшить это дальше. Попытка встраивать (!)
с помощью {-# INLINE (!) #-}
не дали лучших результатов; возможно, ghc уже сделал это.
Ответ 3
Этот мой код оценивает сумму до 2⋅10 ^ 9 за 0,3 секунды и сумму до 10 ^ 12 (18435588552550705911377) за 19,6 секунды (при наличии достаточного количества ОЗУ).
import Control.DeepSeq
import qualified Control.Monad as ControlMonad
import qualified Data.Array as Array
import qualified Data.Array.ST as ArrayST
import qualified Data.Array.Base as ArrayBase
primeLucy :: (Integer -> Integer) -> (Integer -> Integer) -> Integer -> (Integer->Integer)
primeLucy f sf n = g
where
r = fromIntegral $ integerSquareRoot n
ni = fromIntegral n
loop from to c = let go i = ControlMonad.when (to<=i) (c i >> go (i-1)) in go from
k = ArrayST.runSTArray $ do
k <- ArrayST.newListArray (-r,r) $ force $
[sf (div n (toInteger i)) - sf 1|i<-[r,r-1..1]] ++
[0] ++
[sf (toInteger i) - sf 1|i<-[1..r]]
ControlMonad.forM_ (takeWhile (<=r) primes) $ \p -> do
l <- ArrayST.readArray k (p-1)
let q = force $ f (toInteger p)
let adjust = \i j -> do { v <- ArrayBase.unsafeRead k (i+r); w <- ArrayBase.unsafeRead k (j+r); ArrayBase.unsafeWrite k (i+r) $!! v+q*(l-w) }
loop (-1) (-div r p) $ \i -> adjust i (i*p)
loop (-div r p-1) (-min r (div ni (p*p))) $ \i -> adjust i (div (-ni) (i*p))
loop r (p*p) $ \i -> adjust i (div i p)
return k
g :: Integer -> Integer
g m
| m >= 1 && m <= integerSquareRoot n = k Array.! (fromIntegral m)
| m >= integerSquareRoot n && m <= n && div n (div n m)==m = k Array.! (fromIntegral (negate (div n m)))
| otherwise = error $ "Function not precalculated for value " ++ show m
primeSum :: Integer -> Integer
primeSum n = (primeLucy id (\m -> div (m*m+m) 2) n) n
Если ваша функция integerSquareRoot
глючит (как сообщается, некоторые из них), вы можете заменить ее здесь floor . sqrt . fromIntegral
.
Пояснение:
Как следует из названия, оно основано на обобщении известного метода "Люси-Ежик", который в конечном итоге был обнаружен оригинальным плакатом.
Он позволяет рассчитать много сумм формы
(с p prime) без перечисления всех простых чисел до N и во времени O (N ^ 0,75).
Его входы - это функция f (т.е. id
, если вы хотите получить первую сумму), ее суммирующая функция по всем целым числам (т.е. в этом случае сумма первых m целых чисел или div (m*m+m) 2
) и N.
PrimeLucy
возвращает функцию поиска
(с p prime) ограничено определенными значениями n:
.
Ответ 4
Попробуйте это и сообщите мне, насколько это быстро:
-- sum of primes
import Control.Monad (forM_, when)
import Control.Monad.ST
import Data.Array.ST
import Data.Array.Unboxed
sieve :: Int -> UArray Int Bool
sieve n = runSTUArray $ do
let m = (n-1) `div` 2
r = floor . sqrt $ fromIntegral n
bits <- newArray (0, m-1) True
forM_ [0 .. r `div` 2 - 1] $ \i -> do
isPrime <- readArray bits i
when isPrime $ do
let a = 2*i*i + 6*i + 3
b = 2*i*i + 8*i + 6
forM_ [a, b .. (m-1)] $ \j -> do
writeArray bits j False
return bits
primes :: Int -> [Int]
primes n = 2 : [2*i+3 | (i, True) <- assocs $ sieve n]
main = do
print $ sum $ primes 1000000
Вы можете запустить его на ideone. Моим алгоритмом является Сито Эратосфена, и оно должно быть довольно быстрым для малых n. При n = 2 000 000 000 размер массива может быть проблемой, и в этом случае вам нужно будет использовать сегментированное сито. См. мой блог для получения дополнительной информации о сите Эратосфена. См. этот ответ для информации о сегментированном сите (но не в Haskell, к сожалению).