読者です 読者をやめる 読者になる 読者になる

人間だったら考えて

考えて考えて人間だったら考えて

Pythonのmultiprocessingを使ってマンデルブロ集合を並列に計算

この記事は何?

Pythonでプロセス並列を行いたい時,標準ライブラリにmultiprocessingが,外部ライブラリにjoblibがあります. 普段はjoblibを使って並列処理を書いているのですが,multiprocessingをあまり使ったことが無かったので触れてみた時のメモです.

対象の問題はマンデルブロ集合の計算としました,マンデルブロ集合の計算はembarrassingly parallel(単純に並列化可能)なので,並列化の練習にもってこいです.

ソースコード

最初にソースコードの全体を載せておきます.

import time
import numpy as np
import matplotlib.pyplot as plt
import multiprocessing as mp


def check_mandelbrot(a, b, n_loops=1000, threshold=10):
    """
    c = a + bi
    z_{n+1} = z_{n}^2 + c
    z_{0} = 0
    """
    z_a = 0
    z_b = 0
    for i in range(n_loops):
        z_a, z_b = z_a ** 2 - z_b ** 2 + a, 2 * z_a * z_b + b
        if z_a ** 2 + z_b ** 2 > threshold:
            return False
    return True


def sub_check_mandelbrot(p):
    first = p * step
    last = (p + 1) * step
    ans = []
    for ix in range(N):
        x = x_min + x_diff * ix
        for iy in range(first, last):
            y = y_min + y_diff * iy
            if check_mandelbrot(x, y):
                ans.append([x, y])
    return ans


def measure(n_para):
    pool = mp.Pool(n_para)
    s_time = time.time()
    cb = pool.map(sub_check_mandelbrot, range(n_para))
    e_time = time.time()
    print("time: {0:.3f}[s]".format(e_time - s_time))
    return cb


if __name__ == '__main__':
    x_min = -2
    x_max = 0.5
    y_min = -1
    y_max = 1
    N = 200
    n_para = 8
    x_diff = (x_max - x_min) / N
    y_diff = (y_max - y_min) / N
    step = N // n_para
    cb = measure(n_para)
    colors = ["m", "c", "y", "k"] * 2
    for i, c in enumerate(cb):
        c = np.array(c)
        plt.scatter(c[:, 0], c[:, 1], c=colors[i])
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.savefig("{}.png".format(n_para))

まずは入力がマンデルブロ集合に含まれるかどうかを求める関数を書きます.

def check_mandelbrot(a, b, n_loops=1000, threshold=10):
    """
    c = a + bi
    z_{n+1} = z_{n}^2 + c
    z_{0} = 0
    """
    z_a = 0
    z_b = 0
    for i in range(n_loops):
        z_a, z_b = z_a ** 2 - z_b ** 2 + a, 2 * z_a * z_b + b
        if z_a ** 2 + z_b ** 2 > threshold:
            return False
    return True

漸化式を1000回回して値が10を越えたら発散するとみなします.

次に,各プロセスが担当する入力範囲に対してマンデルブロ集合を求める関数を書きます.

def sub_check_mandelbrot(p):
    first = p * step
    last = (p + 1) * step
    ans = []
    for ix in range(N):
        x = x_min + x_diff * ix
        for iy in range(first, last):
            y = y_min + y_diff * iy
            if check_mandelbrot(x, y):
                ans.append([x, y])
    return ans

引数のpはプロセス番号を表します.今回はy軸で分割を行いました.

次に,プロセスを立ち上げて各プロセスにsub_check_mandelbrotを割り当てる関数を書きます.

def measure(n_para):
    pool = mp.Pool(n_para)
    s_time = time.time()
    cb = pool.map(sub_check_mandelbrot, range(n_para))
    e_time = time.time()
    print("time: {0:.3f}[s]".format(e_time - s_time))
    return cb

poolオブジェクトのmap関数に並列に実行したい関数とその引数を渡すと,処理が始まります.

最後にmain関数です.

if __name__ == '__main__':
    x_min = -2
    x_max = 0.5
    y_min = -1
    y_max = 1
    N = 200
    n_para = 8
    x_diff = (x_max - x_min) / N
    y_diff = (y_max - y_min) / N
    step = N // n_para
    cb = measure(n_para)
    colors = ["m", "c", "y", "k"] * 2
    for i, c in enumerate(cb):
        c = np.array(c)
        plt.scatter(c[:, 0], c[:, 1], c=colors[i])
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.savefig("{}.png".format(n_para))

n_paraの値が並列数を表します.結果が正しいことを確認するためにマンデルブロ集合を描画するようにしました.

実行結果

まずはマンデルブロ集合を描画していきます.

  • 並列数: 1 f:id:sz_dr:20160806172640p:plain

  • 並列数: 2 f:id:sz_dr:20160806172652p:plain

  • 並列数: 4 f:id:sz_dr:20160806172715p:plain

  • 並列数: 8 f:id:sz_dr:20160806172729p:plain

各プロセスの担当部分を塗り分けています. y=0に近い部分はマンデルブロ集合に含まれる点が多いため,この部分を担当するプロセスは他のプロセスよりも計算時間がかかることが予想されます.

次に計算時間を見ていきます.計算環境はVirtualBox上でプロセッサー数を4としました.

f:id:sz_dr:20160806173346p:plain

並列数1→2となる時は計算速度が倍になっていますが,それ以上の並列数では緩やかに速くなっていくことがわかります. マンデルブロ集合はy軸について対称なので,並列数1→2の時に計算速度が倍になったのだと考えられます.

その他

joblibの方が簡単だなあ…