Problem description
I have an LSH table builder utility class, shown below (taken from here):
class BuildLSHTable:
    def __init__(self,hash_size=8,dim=2048,num_tables=10,lsh_file="lsh_table.pkl"):
        self.hash_size = hash_size
        self.dim = dim
        self.num_tables = num_tables
        self.lsh = LSH(self.hash_size,self.dim,self.num_tables)
        self.embedding_model = embedding_model
        self.lsh_file = lsh_file
    def train(self,training_files):
        for id,training_file in enumerate(training_files):
            image,label = training_file
            if len(image.shape) < 4:
                image = image[None,...]
            features = self.embedding_model.predict(image)
            self.lsh.add(id,features,label)
        with open(self.lsh_file,"wb") as handle:
            pickle.dump(self.lsh,handle,protocol=pickle.HIGHEST_PROTOCOL)
I then run the following to build my LSH table:
training_files = zip(images,labels)
lsh_builder = BuildLSHTable()
lsh_builder.train(training_files)
Now, when I try to do this through Apache Beam (code below), it throws:
TypeError: can't pickle tensorflow.python._pywrap_tf_session.TF_Operation objects
The Beam code:
def generate_lsh_table(args):
    options = beam.options.pipeline_options.PipelineOptions(**args)
    args = namedtuple("options",args.keys())(*args.values())
    with beam.Pipeline(args.runner,options=options) as pipeline:
        (
            pipeline
            | 'Build LSH Table' >> beam.Map(
                args.lsh_builder.train,args.training_files)
        )
This is how I invoke the Beam runner:
args = {
    "runner": "DirectRunner","lsh_builder": lsh_builder,"training_files": training_files
}
generate_lsh_table(args)
Solution
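An Apache Beam pipeline is converted to a standard (for example, proto) format before it is executed. As part of this, certain pipeline objects, such as DoFns, are serialized (pickled). If your DoFn has instance variables that cannot be serialized, this step cannot proceed. In the code above, the function passed to beam.Map is a bound method of lsh_builder, so pickling it most likely means pickling the whole object, including its TensorFlow embedding_model, which is where the TypeError comes from.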
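The failure can be reproduced without Beam or TensorFlow. The sketch below is only an illustration: `Builder` is a hypothetical stand-in for `BuildLSHTable`, and a `threading.Lock` plays the role of the embedding model because, like a TensorFlow graph handle, it cannot be pickled. It uses the standard `pickle` module, whereas Beam relies on its own pickler, so the exact message will differ.

```python
import pickle
import threading


class Builder:
    """Hypothetical stand-in for BuildLSHTable (not the original class)."""

    def __init__(self):
        self.hash_size = 8
        # Stand-in for the embedding model: a lock, like a TensorFlow
        # graph handle, cannot be pickled.
        self.embedding_model = threading.Lock()

    def train(self, item):
        return item


def try_pickle(obj):
    """Return None if obj pickles, otherwise the error message."""
    try:
        pickle.dumps(obj)
        return None
    except TypeError as exc:
        return str(exc)


builder = Builder()
# Pickling a bound method pickles the instance it is bound to,
# including every attribute stored on that instance.
error = try_pickle(builder.train)
print(error)
```

The point is that passing `builder.train` around drags every attribute of `builder` along with it, so one unpicklable attribute is enough to make the whole call fail.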
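One way to work around this is to load or define such instance objects or modules during execution, rather than creating and storing them when the pipeline is submitted. This may require restructuring your pipeline.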
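A minimal sketch of that pattern in plain Python, under stated assumptions: `FakeModel` and `EmbedFn` are hypothetical names, and the lock inside `FakeModel` only exists to make it unpicklable like the real model. In Beam's Python SDK, a `DoFn` subclass has a `setup` method intended for this kind of initialization; the class below only mimics that shape and is not Beam code.

```python
import pickle
import threading


class FakeModel:
    """Hypothetical stand-in for the embedding model (unpicklable)."""

    def __init__(self, scale):
        self.scale = scale
        self._lock = threading.Lock()  # makes instances unpicklable

    def predict(self, values):
        return [v * self.scale for v in values]


class EmbedFn:
    """Keeps only picklable config in __init__; builds the model later."""

    def __init__(self, scale=2):
        self.scale = scale
        self._model = None  # nothing heavy is stored at construction time

    def setup(self):
        # Runs on the worker, after the object has been unpickled.
        self._model = FakeModel(self.scale)

    def process(self, element):
        if self._model is None:
            self.setup()
        yield self._model.predict(element)


fn = EmbedFn()
# The instance pickles cleanly because the model does not exist yet.
restored = pickle.loads(pickle.dumps(fn))
results = list(restored.process([1, 2]))
print(results)
```

Applied to the question, this would mean keeping `hash_size`, `dim`, and the file path on the object, and creating the embedding model (and possibly the LSH table) inside the execution-time hook instead of in `__init__`.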