Haskellでのメモ化?


136

Haskellで次の関数を効率的に解決する方法に関する任意のポインタ(多数の場合) (n > 108)

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

Haskellでフィボナッチ数を解決するためのメモ化の例を見てきました。これには、必要なnまでのすべてのフィボナッチ数を(遅延的に)計算することが含まれます。しかし、この場合、与えられたnに対して、必要な中間結果はごくわずかです。

ありがとう


110
私が自宅でやっていることは、いくつかの作業であるという意味でのみです:-)
Angel de Vicente

回答:


256

これは、線形以下の時間でインデックスを付けることができる構造を作成することにより、非常に効率的に行うことができます。

でもまず、

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

を定義してみましょうf。ただし、それ自体を直接呼び出すのではなく、「オープン再帰」を使用するようにします。

f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

fを使用すると、メモをとることができますfix f

これにより、次のように呼び出して、のf小さい値に対して何を意味するかをテストできますffix f 123 = 144

これを次のように定義することでメモできます。

f_list :: [Int]
f_list = map (f faster_f) [0..]

faster_f :: Int -> Int
faster_f n = f_list !! n

これはかなりうまく機能し、O(n ^ 3)の時間がかかるものを中間結果を記念するものに置き換えます。

しかし、のメモされた答えを見つけるためにインデックスを作成するだけでも、線形時間はかかりますmf。つまり、次のような結果になります。

*Main Data.List> faster_f 123801
248604

許容範囲内ですが、結果はそれよりもはるかによくスケーリングされません。私たちはもっとうまくやれる!

まず、無限ツリーを定義しましょう:

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

次に、それにインデックスを付ける方法を定義します。これにより、代わりnO(log n)時間でインデックスを持つノードを見つけることができます。

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

...そして、これらのインデックスをいじる必要がないように、自然数でいっぱいのツリーが便利であることがわかる場合があります。

nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

インデックスを作成できるので、ツリーをリストに変換できます。

toList :: Tree a -> [a]
toList as = map (index as) [0..]

あなたがそれtoList natsを与えることを確認することによって、これまでの作業を確認できます[0..]

さて、

f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats

fastest_f :: Int -> Int
fastest_f = index f_tree

上記のリストと同じように機能しますが、各ノードを見つけるために線形時間をとる代わりに、対数時間で追跡することができます。

結果はかなり速くなります:

*Main> fastest_f 12380192300
67652175206

*Main> fastest_f 12793129379123
120695231674999

実際、非常に高速であるためIntInteger上記の手順を実行して置き換え、途方もなく大きな答えをほぼ瞬時に得ることができます。

*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489

*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358

3
私はこのコードを試しましたが、興味深いことに、f_fasterはfよりも遅いようです。これらのリスト参照は本当に物事を遅くしたと思います。natsとindexの定義は私にはかなり不思議に思えたので、物事をより明確にする可能性のある独自の答えを追加しました。
Pitarou

5
無限リストの場合は、リンクされたリスト111111111アイテムを扱う必要があります。ツリーの場合は、log n *到達したノードの数を扱います。
Edward KMETT 2013

2
つまり、リストバージョンでは、リスト内のすべてのノードのサンクを作成する必要がありますが、ツリーバージョンでは、サンクを多数作成する必要がありません。
トム・エリス

7
私はこれがかなり古い投稿であることを知っていますが、呼び出し間でツリー内の不要なパスを保存f_treeしないように、where句で定義するべきではありませんか?
dfeuer 2014

17
CAFにそれを詰め込む理由は、呼び出し全体でメモを取得できるためです。覚えていた高額な通話があった場合は、おそらくCAFに残しておくことになるため、ここに示す手法を使用します。もちろん、実際のアプリケーションでは、永続的なメモ化のメリットとコストの間にトレードオフがあります。しかし、どのようにしてメモ化を達成するかという質問があったとしても、意図的に通話を超えてメモ化を回避する手法で答えるのは誤解を招くと思います。;)
Edward KMETT 2014

17

エドワードの答えはとても素晴らしいので、私はそれを複製し、関数をオープン再帰形式でメモする実装memoListmemoTreeコンビネーターを提供しました。

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)


-- Memoizing using a list

-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f


-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f

12

最も効率的な方法ではありませんが、覚えておいてください:

f = 0 : [ g n | n <- [1..] ]
    where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)

をリクエストすると、存在f !! 144が確認されf !! 143ますが、正確な値は計算されません。それはまだ未知の計算結果として設定されています。計算される正確な値は、必要なものだけです。

したがって、最初は、計算された限り、プログラムは何も知りません。

f = .... 

リクエストを行うと、f !! 12パターンマッチングが開始されます。

f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

今計算を開始します

f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3

これは再帰的にfに別の要求を出すので、

f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0

今、私たちはいくつかをトリクルバックすることができます

f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1

つまり、プログラムは次のことを認識します。

f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

トリクルアップを続けます。

f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3

つまり、プログラムは次のことを認識します。

f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

次に、次の計算を続けますf!!6

f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6

つまり、プログラムは次のことを認識します。

f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

次に、次の計算を続けますf!!12

f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13

つまり、プログラムは次のことを認識します。

f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...

したがって、計算はかなり遅延して行われます。プログラムは、の値がf !! 8存在すること、それがに等しいことを知っていますが、何であるかはわかりg 8ませんg 8


これをありがとう。2次元の解空間をどのように作成して使用しますか?それはリストのリストでしょうか?そしてg n m = (something with) f!!a!!b
vikingsteve 2014年

1
もちろん、できます。真の解決策について、けれども、私はおそらくのように、メモ化ライブラリを使用したいmemocombinators
ホタルブクロ

残念ながらO(n ^ 2)です。
数値による

8

これは、エドワード・クメットの優れた答えの補遺です。

私が彼のコードを試したとき、その定義natsindexかなり不思議に思えたので、理解しやすい別のバージョンを書きました。

私が定義するindexnatsの観点index'nats'

index' t nはの範囲で定義されます[1..]。(index t範囲全体で定義されていることを思い出してください[0..]。)これは、nビットの文字列として扱い、ビットを逆に読み取ることによってツリーを検索します。ビットがの場合、1右側の分岐が行われます。ビットがの場合、0左側の分岐が行われます。最後のビット(これは1)に達すると停止します。

index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
                          (n', 0) -> index' l n'
                          (n', 1) -> index' r n'

同じようnatsに定義されているindexので、index nats n == n常に真である、nats'のために定義されますindex'

nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2)     nats'
    r = fmap (\n -> n*2 + 1) nats'
    nats' = Tree l 1 r

さて、natsindex単純さnats'index'が、値を1ずつシフト:

index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'

ありがとう。私は多変量関数を覚えています。これは、インデックスとnatが実際に行っていることを理解するのに本当に役立ちました。
Kittsil

8

Edward Kmettの回答で述べたように、処理を高速化するには、コストのかかる計算をキャッシュし、それらにすばやくアクセスできるようにする必要があります。

関数を非モナドに保つには、無限の遅延ツリーを構築するソリューション(以前の投稿で示したように)にインデックスを付ける適切な方法を使用して、その目標を達成します。関数の非モナド性をあきらめる場合は、Haskellで利用可能な標準の連想コンテナを「ステートのような」モナド(ステートやSTなど)と組み合わせて使用​​できます。

主な欠点は非モナド関数を取得することですが、構造体に自分でインデックスを付ける必要はなく、連想コンテナの標準実装を使用することができます。

これを行うには、まず、あらゆる種類のモナドを受け入れるように関数を書き直す必要があります。

fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _    0 = return 0
fm recf n = do
   recs <- mapM recf $ div n <$> [2, 3, 4]
   return $ max n (sum recs)

テストでは、少し冗長ですが、Data.Function.fixを使用してメモ化しない関数を定義することもできます。

noMemoF :: (Integral n) => n -> n
noMemoF = runIdentity . fix fm

次に、StateモナドをData.Mapと組み合わせて使用​​して、処理を高速化できます。

import qualified Data.Map.Strict as MS

withMemoStMap :: (Integral n) => n -> n
withMemoStMap n = evalState (fm recF n) MS.empty
   where
      recF i = do
         v <- MS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ MS.insert i v'
               return v'

マイナーな変更により、代わりにコードをData.HashMapで動作するように調整できます。

import qualified Data.HashMap.Strict as HMS

withMemoStHMap :: (Integral n, Hashable n) => n -> n
withMemoStHMap n = evalState (fm recF n) HMS.empty
   where
      recF i = do
         v <- HMS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ HMS.insert i v'
               return v'

永続的なデータ構造の代わりに、変更可能なデータ構造(Data.HashTableなど)をSTモナドと組み合わせて試すこともできます。

import qualified Data.HashTable.ST.Linear as MHM

withMemoMutMap :: (Integral n, Hashable n) => n -> n
withMemoMutMap n = runST $
   do ht <- MHM.new
      recF ht n
   where
      recF ht i = do
         k <- MHM.lookup ht i
         case k of
            Just k' -> return k'
            Nothing -> do 
               k' <- fm (recF ht) i
               MHM.insert ht i k'
               return k'

メモ化のない実装と比較して、これらの実装のいずれかを使用すると、巨大な入力に対して、数秒待つ必要がなく、マイクロ秒単位で結果を取得できます。

Criterionをベンチマークとして使用すると、Data.HashMapを使用した実装は、タイミングが非常に似ていたData.MapおよびData.HashTableよりもわずかに(約20%)パフォーマンスが高いことがわかりました。

ベンチマークの結果は少し驚くべきものでした。私の最初の感想は、HashTableは変更可能であるため、HashMapの実装よりも優れているということでした。この最後の実装では、パフォーマンスの欠陥が隠れている可能性があります。


2
GHCは不変の構造を中心に最適化する非常に良い仕事をします。Cの直感は、いつもうまくいくとは限りません。
John Tyree、2015年

3

数年後、私はこれを見てzipWith、ヘルパー関数を使用して線形時間でこれをメモする簡単な方法があることに気づきました:

dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

dilateという便利な特性がありdilate n xs !! i == xs !! div i nます。

したがって、f(0)が与えられているとすると、これにより計算が簡単になります。

fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
  where (.+.) = zipWith (+)
        infixl 6 .+.
        (#/) = flip dilate
        infixl 7 #/

元の問題の説明とよく似ており、線形解を与えます(sum $ take n fsO(n)を使用します)。


2
したがって、これは生成的(共起的?)、または動的プログラミングのソリューションです。通常のフィボナッチが行っているように、生成された各値ごとにO(1)時間を取ります。すごい!そして、EKMETTのソリューションは、対数の大きなフィボナッチのようなもので、大きな数値にはるかに速く到達し、中間の多くをスキップします。これで大丈夫ですか?
ネスは

あるいは、生成されるシーケンスへの3つのバックポインターと、それに沿って進むそれぞれの異なる速度で、ハミング数の方が近いかもしれません。すごく可愛い。
ネスは

2

エドワード・クメットの答えのさらに別の補足:自己完結型の例:

data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)

memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
  where nats = go 0 1
        go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
          where s' = 2*s
        index (NatTrie l v r) i
          | i <  0    = f (index_to_arg i)
          | i == 0    = v
          | otherwise = case (i-1) `divMod` 2 of
             (i',0) -> index l i'
             (i',1) -> index r i'

memoNat = memo1 id id 

これを次のように使用して、単一の整数引数を持つ関数をメモします(例:フィボナッチ):

fib = memoNat f
  where f 0 = 0
        f 1 = 1
        f n = fib (n-1) + fib (n-2)

負でない引数の値のみがキャッシュされます。

負の引数の値もキャッシュするmemoIntには、次のように定義されたを使用します。

memoInt = memo1 arg_to_index index_to_arg
  where arg_to_index n
         | n < 0     = -2*n
         | otherwise =  2*n + 1
        index_to_arg i = case i `divMod` 2 of
           (n,0) -> -n
           (n,1) ->  n

2つの整数引数を持つ関数の値をキャッシュするmemoIntIntには、次のように定義されたを使用します。

memoIntInt f = memoInt (\n -> memoInt (f n))

2

インデックスなしのソリューションであり、Edward KMETTに基づいていません。

共通のサブツリーを共通の親に分解します(およびのf(n/4)f(n/2)で共有されf(n/4)、およびのf(n/6)f(2)で共有されますf(3))。それらを親の単一変数として保存することにより、サブツリーの計算が1回行われます。

data Tree a =
  Node {datum :: a, child2 :: Tree a, child3 :: Tree a}

f :: Int -> Int
f n = datum root
  where root = f' n Nothing Nothing


-- Pass in the arg
  -- and this node's lifted children (if any).
f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a)-> a
f' 0 _ _ = leaf
    where leaf = Node 0 leaf leaf
f' n m2 m3 = Node d c2 c3
  where
    d = if n < 12 then n
            else max n (d2 + d3 + d4)
    [n2,n3,n4,n6] = map (n `div`) [2,3,4,6]
    [d2,d3,d4,d6] = map datum [c2,c3,c4,c6]
    c2 = case m2 of    -- Check for a passed-in subtree before recursing.
      Just c2' -> c2'
      Nothing -> f' n2 Nothing (Just c6)
    c3 = case m3 of
      Just c3' -> c3'
      Nothing -> f' n3 (Just c6) Nothing
    c4 = child2 c2
    c6 = f' n6 Nothing Nothing

    main =
      print (f 123801)
      -- Should print 248604.

コードは一般的なメモ化関数に簡単に拡張できません(少なくとも、それを行う方法がわからないでしょう)。サブ問題がどのようにオーバーラップするかを実際に考える必要がありますが、戦略は一般的な複数の非整数パラメーターに対して機能する必要があります。(2つの文字列パラメーターについて考えました。)

メモは各計算後に破棄されます。(ここでも、2つの文字列パラメーターについて考えていました。)

これが他の回答よりも効率的かどうかはわかりません。各ルックアップは技術的には1つまたは2つのステップ(「子供またはあなたの子供を見る」)のみですが、多くの余分なメモリを使用する可能性があります。

編集:このソリューションはまだ正しくありません。共有は不完全です。

編集:それは正しくsubchildrenを共有する必要がありますが、私はこの問題は自明でない共有をたくさん持っていることに気づいたn/2/2/2n/3/3同じかもしれません。問題は私の戦略によく合いません。

弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.