Problem description
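We've run into a very strange problem when a Celery task tries to read a model file that we saved with joblib. This is a Flask application using the factory pattern. Here is the overall setup: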
`fit_model.py`:

```python
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
import joblib


class MyClass(Pipeline):
    pass


if __name__ == "__main__":
    to_persist = MyClass(steps=[("step", OneHotEncoder())])
    joblib.dump(to_persist, "dummy.model")
```
`load_model.py`:

```python
import joblib

if __name__ == "__main__":
    with open("dummy.model", "rb") as f:
        my_object = joblib.load(f)
    assert my_object is not None
```
`tasks.py`:

```python
import joblib

# celery_app is the Celery instance defined in application.py (import not shown)
@celery_app.task()
def dummy_task():
    _ = joblib.load("dummy.model")
```
If we run `fit_model.py` first and then `load_model.py`, it works fine. However, when this runs as a Celery task and we call `dummy_task.delay()`, we get the following:
```
  File "/usr/local/lib/python3.7/site-packages/celery/app/trace.py", line 412, in trace_task
    R = retval = fun(*args, **kwargs)
  File "/usr/local/lib/python3.7/site-packages/celery/app/trace.py", line 704, in __protected_call__
    return self.run(*args, **kwargs)
  File "/lit_service/tasks.py", line 32, in dummy_task
    _ = joblib.load("dummy.model")
  File "/usr/local/lib/python3.7/site-packages/joblib/numpy_pickle.py", line 585, in load
    obj = _unpickle(fobj, filename, mmap_mode)
  File "/usr/local/lib/python3.7/site-packages/joblib/numpy_pickle.py", line 504, in _unpickle
    obj = unpickler.load()
  File "/usr/local/lib/python3.7/pickle.py", line 1088, in load
    dispatch[key[0]](self)
  File "/usr/local/lib/python3.7/pickle.py", line 1376, in load_global
    klass = self.find_class(module, name)
  File "/usr/local/lib/python3.7/pickle.py", line 1430, in find_class
    return getattr(sys.modules[module], name)
AttributeError: module 'celery.bin.celery' has no attribute 'MyClass'
```
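The traceback ends in pickle's `find_class`, which suggests the model file does not contain `MyClass` itself, only a reference to the module and name where it should be found. A minimal stdlib-only sketch of that lookup is below: it records every (module, name) pair the unpickler requests. The module name `script_main` is a hypothetical stand-in for the `__main__` module of the process that ran `joblib.dump`, and in-memory modules replace real files so the example is self-contained; no sklearn, joblib or Celery is involved.

```python
import io
import pickle
import sys
import types


class RecordingUnpickler(pickle.Unpickler):
    """Unpickler that records every (module, name) global the stream asks for."""

    def __init__(self, file):
        super().__init__(file)
        self.lookups = []

    def find_class(self, module, name):
        self.lookups.append((module, name))
        return super().find_class(module, name)


# Hypothetical stand-in for fit_model.py run as a script: "script_main"
# plays the role of __main__ in the process that dumps the object.
dumper = types.ModuleType("script_main")
exec("class MyClass:\n    pass\n", dumper.__dict__)
sys.modules["script_main"] = dumper
blob = pickle.dumps(dumper.MyClass())

# In the loading process (e.g. the Celery worker) a module with that name
# exists, but it is a different program that never defined MyClass.
sys.modules["script_main"] = types.ModuleType("script_main")

unpickler = RecordingUnpickler(io.BytesIO(blob))
try:
    unpickler.load()
    error = None
except AttributeError as exc:
    error = exc
```

After `load()` fails, `unpickler.lookups` holds a single lookup of `MyClass` in `script_main`, mirroring the worker's failed lookup of `MyClass` in its own `__main__`.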
The Flask Celery worker is set up in `application.py`:
```python
from celery import Celery
from flask import Flask

# settings: the project's configuration module (import not shown)
celery_app = Celery(
    __name__, broker=settings.CELERY_BROKER_URL, backend=settings.CELERY_BROKER_URL
)


def create_app(**config_overrides):
    app = Flask(__name__)
    # start celery
    celery_app.conf.update(app.config)
    celery_app.conf.imports = ["tasks"]
    celery_app.conf.task_routes = {"tasks.*": {"queue": "lit_tasks"}}
    return app
```
The worker is then launched as a Docker container like this:

```
celery worker -A celery_worker.celery_app --loglevel=INFO --queues lit_tasks
```
Any clues on where to look?
Solution
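Try defining `MyClass` in its own file (e.g. `my_class.py`), then add `from my_class import MyClass` to `fit_model.py`.

The reason is that pickle, which joblib uses under the hood, stores a class by reference to its module and name, not by value. Because `MyClass` is defined in `fit_model.py` and that file is run as a script, the class is recorded as `__main__.MyClass`. When the Celery worker unpickles the file, its `__main__` module is `celery.bin.celery` (as the traceback shows), which has no `MyClass`, hence the `AttributeError`. With the class in `my_class.py`, the pickle refers to `my_class.MyClass` instead, so it loads in any process that can import `my_class`; make sure that file is also on the worker's import path.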
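A minimal sketch of why moving the class helps, again using only the standard-library `pickle` module. Here `my_class` is built in memory as a hypothetical stand-in for the `my_class.py` file so the example runs on its own; in a real project it would be an ordinary module file.

```python
import pickle
import sys
import types

# Hypothetical in-memory stand-in for my_class.py so the sketch is
# self-contained; in the real project this is a file on the import path.
my_class = types.ModuleType("my_class")
exec("class MyClass:\n    pass\n", my_class.__dict__)
sys.modules["my_class"] = my_class

# What fit_model.py would do after the fix: import the class instead of
# defining it in the script (resolves through sys.modules here).
from my_class import MyClass

blob = pickle.dumps(MyClass())

# The pickle now refers to my_class.MyClass rather than __main__.MyClass,
# so any process that can import my_class can load it, whatever its own
# __main__ module happens to be.
restored = pickle.loads(blob)
```

The same reasoning carries over to joblib, since its loader delegates class lookup to pickle's `find_class`, as the traceback in the question shows.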