Mandelbrot Numba/Numpy 矢量化?

问题描述

我使用 kivy 在 Python 中编写了一个交互式 mandelbrot 渲染器,您可以在其中使用鼠标指针进行缩放,并正在尽我所能对其进行优化。我目前使用这个实现来渲染设置/缩放(这是一个小片段,只是用来渲染它的两个函数):

import numba as nb
import numpy as np


@nb.njit(cache= True,parallel = True)
def mandelbrot(c_r,c_i,maxIt): #mandelbrot function
        z_r = 0 
        z_i = 0
        z_r2 = 0
        z_i2= 0
        for x in nb.prange(maxIt):
            z_i = 2 * z_r * z_i + c_i
            z_r = z_r2 - z_i2 + c_r
            z_r2 = z_r * z_r
            z_i2 = z_i * z_i
            if z_r2 + z_i2 > 4:
                return x
        return maxIt

@nb.njit(cache= True,parallel = True)
def DrawSet(W,H,xStart,xdist,yStart,ydist,maxIt):
        array = np.zeros((H,W,3),dtype=np.uint8) #array that holds 'hsv' tuple for every pixel
        for x in nb.prange(0,W):
            c_r = (x/W)* xdist + xStart #some math to calculate real part
            for y in range (0,H):
                c_i = -((y/H) * ydist + yStart) #some more math to calculate imaginary part
                cIt = mandelbrot(c_r,maxIt) 
                color = int((255 * cIt) / maxIt)
                array[y,x] = (color,255,255) #adds hue value 
        return array #returns hsv array,gets later displayed using PIL

我目前的表现相当不错。它可以在大约 0.08 - 0.09 秒内渲染一个 500x500 的区域,其中每个点都有界(所以基本上是黑色图片,最坏的情况),迭代 300 次。我将 Numba JIT 与并行范围函数“prange()”一起使用,这有很大帮助。

但是,我听说矢量化通常是渲染此类分形的最快方法。经过大量研究(我对矢量化很陌生),我设法将这个实现放在一起:

import numba as nb
import numpy as np

def DrawSet(W,xEnd,yEnd,maxIt):

    array = np.zeros((H,dtype = np.uint8) # 3D array containing 'hsv' tuple (hue,saturation,value) of each pixel

    x = np.linspace(xStart,W).reshape((1,W)) #scaling horizontal pixels to x-axis
    y = np.linspace(yStart,H).reshape((H,1)) #scaling vertical pixels to y-axis
    c = x + 1j * y #creating complex plane out of x axis (real) and y axis (imaginary)
    z = np.zeros(c.shape,dtype= np.complex128)
    div_time = np.zeros(z.shape,dtype= int)
    m = np.full(c.shape,True,dtype= bool)

    div_time = loop(z,c,div_time,m,maxIt)
    
    array[:,:,0] = (div_time/maxIt) * 255 -20 #adding 'hue' value
    array[:,1] = 255 #adding 'saturation' value
    array[:,2] = 255 #adding 'value'
    
    return array


@nb.vectorize(nb.int64[:,:](nb.complex128[:,:],nb.complex128[:,nb.int64[:,nb.boolean[:,nb.int64))
def loop(z,maxIt):

    for i in range(maxIt):
        z[m] = z[m]**2 + c[m]
        diverged = np.greater(np.abs(z),2,out=np.full(c.shape,False),where=m)
        div_time[diverged] = i      
        m[np.abs(z) > 2] = False
    return div_time

没有@nb.vectorize 装饰器,它运行得非常慢。 (500x500 的最坏情况为 4 秒,300 It。)。使用 @nb.vectorize 装饰器,我收到此错误


Traceback (most recent call last):
   File "Mandelbrot.py",line 13,in <module>
     from test import DrawSet
   File "C:\Users\User\Documents\Code\Python\Mandelbrot-GUI\test.py",line 25,in <module>
     def loop(z,maxIt):
   File "C:\Users\User\AppData\Local\Programs\Python\python38\lib\site-packages\numba\np\ufunc\decorators.py",line 119,in wrap
     for sig in ftylist:
 TypeError: 'Signature' object is not iterable

我做错了什么?我是否以正确的方式定义了所有的 numba 签名? 这种矢量化方法会超过我当前的实现吗?

我会感谢每一个建议!提前致谢。

解决方法

您的实现已经矢量化了!

矢量化的想法是创建 universal functions 对数组进行元素操作。您只需定义对单个元素执行的操作,向量化机制将允许使用数组调用您的函数。

该函数计算单个点 c:

def mandelbrot_point(c,max_it):
    z = 0j
    for i in range(max_it):
        z = z**2 + c
        if abs(z) > 2:
            return i
    return 0

您可以使用 Numpy 对其进行矢量化:

@np.vectorize
def mandelbrot_numpy(c,max_it):
    z = 0j
    for i in range(max_it):
        z = z**2 + c
        if abs(z) > 2:
            return i
    return 0

或者您可以使用 Numba 对其进行矢量化。请注意,函数的签名描述了如何处理单个点:

@nb.vectorize([nb.int64(nb.complex128,nb.int64)])
def mandelbrot_numba(c,max_it):
    z = 0j
    for i in range(max_it):
        z = z**2 + c
        if abs(z) > 2:
            return i
    return 0

然后您可以使用任意维数的标量或数组调用向量化函数:

>>> p = 0.4+0.4j
>>> mandelbrot_point(p,99)
8
>>> mandelbrot_numpy(p,99)
array(8)
>>> mandelbrot_numba(p,99)
8

>>> x = np.linspace(-2,2,11)
>>> mandelbrot_numpy(x,99)
array([0,6,1,1])
>>> mandelbrot_numba(x,1])

>>> x = np.atleast_2d(x)
>>> y = x.T
>>> c = x + 1j * y
>>> mandelbrot_numpy(c,99)
array([[ 0,0],[ 0,3,5,17,8,1],0]])
>>> mandelbrot_numba(c,0]])

Numpy 的 vectorize 极大地简化了您的代码,但正如文档所说,它主要是为了方便,而不是为了性能。实现本质上是一个 for 循环。

根据我的测量,Numpy 向量化版本比您的原始实现略快,而 Numba 向量化版本快一个数量级。