问题描述
我试图更好地理解numba装饰器,尤其是guvectorize。
我试图修改它以计算风速。
这就是我得到的:
import numpy as np
import xarray as xr
import datetime
import glob
import dask
import sys
import os
import tempfile
from numba import float64,guvectorize,vectorize,njit
import time as t
@guvectorize(
"(float64,float64,float64)","(),() -> ()",nopython=True,)
def calcWindspeed_ufunc(u,v,out):
out = np.sqrt( u**2 + v**2 )
def calcWindspeed(u,v):
return xr.apply_ufunc(calcWindspeed_ufunc,u,input_core_dims=[[],[]],output_core_dims=[[]],# vectorize=True,dask="parallelized",output_dtypes=[u.dtype])
def main():
nlon = 120
nlat = 100
ntime = 3650
lon = np.linspace(129.4,153.75,nlon)
lat = np.linspace(-43.75,-10.1,nlat)
time = np.linspace(0,365,ntime)
#< Create random data
u = 10 * np.random.rand(len(time),len(lat),len(lon))
u = xr.Dataset({"u": (["time","lat","lon"],u)},coords={"time": time,"lon": lon,"lat": lat})
u = u.chunk({'time':365})
u = u['u']
v = u.copy()
start = t.time()
ws_xr = np.sqrt( u**2 + v**2 ).load()
end = t.time()
print('It took xarray {} seconds!'.format(end-start))
start = t.time()
ws_ufunc = calcWindspeed(u,v).load()
end = t.time()
print('It took numba {} seconds!'.format(end-start))
# Difference of the output
print( (ws_xr-ws_ufunc).max() )
if __name__ == '__main__':
import dask.distributed
import sys
# Get the number of cpuS in the job and start a dask.distributed cluster
mem = 190
cores = 4
memory_limit = '{}gb'.format(int(max(mem/cores,4)))
client = dask.distributed.Client(n_workers=cores,threads_per_worker=1,memory_limit=memory_limit,local_dir=tempfile.mkdtemp())
#< Print client summary
print('### Client summary')
print(client)
print('\n\n')
#< Call the main function
main()
#< Close the client
client.shutdown()
这在技术上有效(运行),但是输出错误。两次计算之间的差应接近0,但以我为例。
我不明白自己在做什么错。
谢谢您的帮助!
解决方法
一些想法:
- 如果仅调用numpy,则无需使用numba。 Numba运行已编译的代码,但是当前示例实际上没有任何代码...
- 如果您要使用它在多个维度上运行,则可以单独使用
xr.apply_ufunc
来做到这一点 - 如果您希望其他人参与该示例,可以将其缩小到最小尺寸吗?当前有dask,xarray和numba -如果您将其删除,diff是否仍然有效?
作为参考,以下是一些我使用xarray和numba https://github.com/shoyer/numbagg/blob/master/numbagg/moving.py
编写的函数