Updating a tqdm progress bar inside a vectorized function

Problem description

I have a function with the following call signature:

import numpy as np
@np.vectorize
def evolve_system(a0,e0,beta,m1,m2,lt,pbar):
    ...
    pbar.update(1)
    ...

    return

and it is called like this:

from tqdm import tqdm

with tqdm(total=len(df)) as pbar:
    n,m,ef,Pf,c = evolve_system(df['a0'].values,df['e0'].values,df['beta'].values,df['m1'].values,df['m2'].values,df['lifetime'].values,pbar
                                   )

where df is a pandas DataFrame. Running the code produces the following traceback:

Traceback (most recent call last):
  File "/home/sean/anaconda3/lib/python3.7/site-packages/julia/pseudo_python_cli.py",line 308,in main
    python(**vars(ns))
  File "/home/sean/anaconda3/lib/python3.7/site-packages/julia/pseudo_python_cli.py",line 59,in python
    scope = runpy.run_path(script,run_name="__main__")
  File "/home/sean/anaconda3/lib/python3.7/runpy.py",line 263,in run_path
    pkg_name=pkg_name,script_name=fname)
  File "/home/sean/anaconda3/lib/python3.7/runpy.py",line 96,in _run_module_code
    mod_name,mod_spec,pkg_name,script_name)
  File "/home/sean/anaconda3/lib/python3.7/runpy.py",line 85,in _run_code
    exec(code,run_globals)
  File "PhaseSpace.py",line 15,in <module>
    df_B = takahe.evolve.period_eccentricity(dataframes_Bray[Z].sample(1000))
  File "/home/sean/Documents/takahe/takahe/evolve.py",line 209,in period_eccentricity
    pbar
  File "/home/sean/.local/lib/python3.7/site-packages/numpy/lib/function_base.py",line 2108,in __call__
    return self._vectorize_call(func=func,args=vargs)
  File "/home/sean/.local/lib/python3.7/site-packages/numpy/lib/function_base.py",line 2198,in _vectorize_call
    for x,t in zip(outputs,otypes)])
  File "/home/sean/.local/lib/python3.7/site-packages/numpy/lib/function_base.py",in <listcomp>
    for x,otypes)])
ValueError: setting an array element with a sequence.

As far as I can tell, this is caused by the pbar argument: removing it from both the definition and the call makes the code run.

Is there a way around this? Can I call pbar.update() inside a vectorized function?

Solution

Thanks to @hpaulj's suggestion, the simplest solution seems to be to put pbar in the global scope and use it from there, i.e.:

from tqdm import tqdm
import numpy as np

@np.vectorize
def evolve_system(a0,e0,beta,m1,m2,lt):
    global pbar  # pbar is only read here, so this declaration is optional
    ...
    pbar.update(1)
    ...

    return

def main():
    global pbar  # required: the with-statement below assigns to pbar
    ...
    with tqdm(total=len(df)) as pbar:
        n,m,ef,Pf,c = evolve_system(df['a0'].values,df['e0'].values,df['beta'].values,df['m1'].values,df['m2'].values,df['lifetime'].values,)
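An alternative that avoids the global is sketched below, under some assumptions. np.vectorize accepts an `excluded` parameter that passes the named arguments straight through to the function instead of treating them as array inputs to broadcast, and an `otypes` parameter. Without `otypes`, NumPy calls the function once on the first element just to infer the output type, so the bar advances one extra step. The global-pbar version above does not set `otypes`, so it should end one step past `total`. The function `evolve_one` and its arithmetic are hypothetical placeholders, because the real body of `evolve_system` is elided. It also returns a single float, whereas the real function returns five values and would need five entries in `otypes`.

```python
import numpy as np
from tqdm import tqdm


def evolve_one(a0, e0, pbar):
    # Hypothetical stand-in for evolve_system; the real physics is elided.
    pbar.update(1)                  # advance the bar once per call
    return a0 * (1.0 - e0 ** 2)


a0 = np.array([1.0, 2.0, 3.0])
e0 = np.array([0.0, 0.5, 0.1])

# excluded={'pbar'}: the keyword argument pbar is handed to evolve_one as-is
# instead of being treated as an array input to broadcast.
# otypes=[float]: numpy no longer calls evolve_one on the first element
# just to discover the output dtype, so update() runs once per element.
evolve = np.vectorize(evolve_one, excluded={'pbar'}, otypes=[float])

with tqdm(total=len(a0)) as pbar:
    out = evolve(a0, e0, pbar=pbar)

# For comparison: without otypes, the dtype probe adds one extra call,
# so this bar ends one step past its total.
probing = np.vectorize(evolve_one, excluded={'pbar'})
with tqdm(total=len(a0)) as pbar_probe:
    probing(a0, e0, pbar=pbar_probe)
```

To keep the decorator form, `functools.partial(np.vectorize, excluded={'pbar'}, otypes=[float])` can be applied as the decorator. A string in `excluded` only matches a keyword argument, so pbar must then be passed as `pbar=pbar`. To exclude a positional argument, use its index instead, e.g. `excluded={6}` for the seventh parameter of `evolve_system`.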