问题描述
我正在编写代码来可视化 Mandelbrot 集和其他分形。下面是正在运行的代码片段。代码按原样运行得很好,但我正在尝试优化它以更快地制作更高分辨率的图像。我曾尝试在 fractal()
上使用缓存,以及来自 Numba 的 @jit
和 @njit
。缓存导致崩溃(我假设是内存溢出),@jit
只是将我的程序的执行速度减慢了 6 倍。我也知道有许多数学方法可以让我的代码运行更快,正如我在维基百科页面上看到的那样,但我想看看我是否可以获得上述方法之一或其他一些替代方法。
为了连续创建多个图像(制作缩放动画,就像这个)我已经实现了多处理(似乎一次运行 9 个进程)但我不知道如何在创建中实现相同的单个高分辨率图像。
这是我的代码片段:
import numpy as np
import cv2
import cmath
import math
# pick the fractal
def fractal(z,c):
# Mandelbrot
if fractal_type == 0:
return z**d + c
# Burning Ship
if fractal_type == 1:
return complex(abs(z.real),abs(z.imag))**d + c
#naive escape time algorithm
def naive_escape(arr):
h = arr[0]
w = arr[1]
d = arr[2]
zoom = pow(1.5,arr[3]) * pow(10,int(np.log10(h)))
x_cen = arr[4]
y_cen = arr[5]
for i in range(w):
sys.stdout.write("\r{0:03}%".format(np.round(i/w * 100,4)))
sys.stdout.flush()
for j in range(h):
it = 0
#coordinates
cx = i - int(w/2)
cy = j - int(h/2)
#scaling
sx = (cx / (zoom)) + x_cen
sy = (cy / (zoom)) - y_cen
c = complex(sx,sy)
z = complex(0.0,0.0)
while ((z.real)**2 + (z.imag)**2 <= 2**d) and (it < max_it):
z = fractal(z,c)
it += 1
img[j][i] = color_dict[it]
sys.stdout.write("\n")
name = "fractal"
cv2.imwrite("{}.png".format(name),img)
print("\n{} created!\n".format(name),fractal_type)
我应该澄清一下,着色函数 naive_escape()
接受数组输入的原因是因为我实现了多处理。由于多处理中的 map()
只允许我们用一个输入映射函数,所以我只传递一个包含所有输入值的数组。
上面粘贴的代码是来自一个更大文件的片段,所以请原谅任何语法错误。
任何有助于加快我的代码速度的帮助将不胜感激!
解决方法
This older answer 专门处理矢量化,但可以进行一些额外的优化。
你可以从 Numpy 向量化开始,方便但不是很快:
@np.vectorize
def mandelbrot_numpy(c: complex,max_it: int) -> int:
z = c
for i in range(max_it):
if abs(z) > 2:
return i
z = z**2 + c
return 0
或者 Numba 向量化,将速度提高一个数量级:
@nb.vectorize([nb.u2(nb.c16,nb.i8)])
def mandelbrot_numba(c: complex,max_it: int) -> int:
z = c
for i in range(max_it):
if abs(z) > 2:
return i
z = z**2 + c
return 0
然后你可以应用一些常用的优化:
@nb.vectorize([nb.u2(nb.c16,nb.u2)])
def mandelbrot_numba_opt(c: complex,max_it: int) -> int:
x = cx = c.real
y = cy = c.imag
for i in range(max_it):
x2 = x*x
y2 = y*y
if x2 + y2 > 4:
return i
y = (x+x)*y + cy
x = x2 - y2 + cx
return 0
你也可以并行化它(在这个例子中按行):
@nb.njit([nb.u2[:,:](nb.c16[:,:],nb.u2)],parallel=True)
def mandelbrot_parallel(c: np.ndarray,max_it: int) -> np.ndarray:
result = np.zeros_like(c,dtype=nb.u2)
for row in nb.prange(len(c)):
result[row] = mandelbrot_numba_opt(c[row],max_it)
return result
1000x1000 阵列上的一些计时:
N = 1000
x = np.linspace(-2,2,N).reshape((1,-1))
y = x.T
c = x + 1j * y
%timeit mandelbrot_numpy(c,99)
1.59 s ± 40.9 ms per loop (mean ± std. dev. of 7 runs,1 loop each)
%timeit mandelbrot_numba(c,99)
100 ms ± 406 µs per loop (mean ± std. dev. of 7 runs,10 loops each)
%timeit mandelbrot_numba_opt(c,99)
35 ms ± 140 µs per loop (mean ± std. dev. of 7 runs,10 loops each)
%timeit mandelbrot_parallel(c,99)
10.9 ms ± 64.3 µs per loop (mean ± std. dev. of 7 runs,100 loops each)