Python で並列計算 (multiprocessing モジュール) | 複数の引数を取る関数を map() メソッドで並列に走らせる

いまやノートブックでもデュアルコアクアッドコアが当たり前になってきたので、似たような処理を延々と繰り返すようなデータ解析のプログラムなどは並列化するとかなりの恩恵が得られる。Python ではバージョン2.6から multiprocessing モジュールというのが標準ライブラリに入っており、比較的簡単に並列計算のスクリプトが書けてしまう。

個人的に並列化できると一番うれしいのは、統計量の有意性の検定のためにサロゲートデータを大量に作って、そこから(帰無仮説に基づいた)統計量の分布を推定するときの、それぞれのサロゲートごとの統計量の計算だ。この場合、個々の計算は完全に独立なので、プロセス間の通信などは考える必要がなく、非常に単純に並列化できる。ということで早速スクリプトを書いてみたが、まあ書くだけならすぐ書けるのだが、使い勝手や汎用性を求めると、ちょっとややこしいことを考えなければならなくなってしまった。

とりあえずの要件として、求めたい統計量の計算をするための関数はあらかじめ手元にあるので、これをそのまま使って並列化したいということがある。つまり、並列化のためだけにわざわざ新しい関数を書くというようなめんどくさいことはしたくない。multiprocessing モジュールで並列化する場合、一番単純なのは Process クラスのオブジェクトに並列で計算したい関数を設定して(multiprocessing.Process(func, args))、それを start() メソッドで走らせるというやり方だが、これだと関数の返り値を受け取るのに、親プロセスと子プロセスの間の通信を使わないといけない。そのためにはこれらのプロセスの間にパイプを通して(multiprocessing.Pipe())、子プロセスの関数にこのパイプ経由で計算結果を出力するようにさせないといけないのだが、これには当然関数の書き換えが必要となってくる。

何かもっと楽な方法はないかと調べてみると、Process のかわりに Pool というのを使うと、返り値を楽に取り扱えることが分かった。すなわち、results = multiprocessing.Pool().map(func, arglist) とすると、arglist 内の要素のそれぞれを引数とした func を並列に計算し、その結果をリストとして results に格納してくれる。同時に走るプロセス数は、マシンのコア数を調べて勝手にそれに合わせてくれる(Pool()の引数で指定することも可)。これはとても楽だ。

が、実際に自分の関数を使って並列化しようとしてみると、困ったことに気付いた。統計量の計算をするとき、しばしばいくつかのパラメータを設定する必要があるのだが、手元の関数では、これを複数の引数を渡すことで実装している。ところが、Pool オブジェクトの map() メソッドでは、関数に複数の引数を渡すことができない。リストにして渡せばひとつの引数にまとめられるが、そうすると再び関数の書き換えが必要になってしまう。これは困った・・・

ということで、なんとかして既存の関数を書きかえることなく、複数の引数を渡せるようにしようといろいろ試してみた。いまのところ、以下のようなラッパー関数を使うのが一番簡単で、かつ使いやすいんではないかというところに落ち着いた。

def argwrapper(args):
    '''
    ラッパー関数
    '''
    return args[0](*args[1:])

def myfunc(a, b):
    '''
    並列に計算したい関数
    '''
    return a*b

if __name__ == '__main__':
    from multiprocessing import Pool
    p = Pool()
    func_args = []
    for a in xrange(1,10):
        for b in xrange(1,10):
            func_args.append( (myfunc, a, b) )
    results = p.map(argwrapper, func_args)
    print results

argwrapper() はリストを引数に取り、このリストの最初の要素を関数として実行し、それ以降の要素をその関数に渡す引数とする。main 部分では、計算したい関数とそれへの引数をまとめたリストを作り、それを map() メソッドで argwrapper() に渡している。この方法なら、あらゆる既存の関数に関して、複数の引数を与え map() メソッドで並列に計算させることができるはず。

複数の引数を与えるだけなら map() メソッドのかわりに apply_async() メソッドを使ってもいいのだが、その場合、結果をひとつづつ get() メソッドで取り出さないといけない。結果が直接リストとして返ってくる上記の方法のほうが、個人的には使い勝手がいいと思う。