Best way to build an LSH table with Apache Beam and Dataflow

Problem description

I have an LSH table builder utility class, shown below (adapted from here):

class BuildLSHTable:
    def __init__(self, hash_size=8, dim=2048, num_tables=10, lsh_file="lsh_table.pkl"):
        self.hash_size = hash_size
        self.dim = dim
        self.num_tables = num_tables
        self.lsh = LSH(self.hash_size, self.dim, self.num_tables)
        self.embedding_model = embedding_model
        self.lsh_file = lsh_file

    def train(self, training_files):
        for id, training_file in enumerate(training_files):
            image, label = training_file
            if len(image.shape) < 4:
                image = image[None, ...]
            features = self.embedding_model.predict(image)
            self.lsh.add(id, features, label)

        with open(self.lsh_file, "wb") as handle:
            pickle.dump(self.lsh, handle, protocol=pickle.HIGHEST_PROTOCOL)

I then run the following to build my LSH table:

training_files = zip(images, labels)
lsh_builder = BuildLSHTable()
lsh_builder.train(training_files)

Now, when I try to do this through Apache Beam (code below), it throws:

TypeError: can't pickle tensorflow.python._pywrap_tf_session.TF_Operation objects

Code used with Beam

def generate_lsh_table(args):
    options = beam.options.pipeline_options.PipelineOptions(**args)
    args = namedtuple("options", args.keys())(*args.values())

    with beam.Pipeline(args.runner, options=options) as pipeline:
        (
            pipeline
            | 'Build LSH Table' >> beam.Map(
                args.lsh_builder.train, args.training_files)
        )

This is how I'm invoking the Beam runner:

args = {
    "runner": "DirectRunner",
    "lsh_builder": lsh_builder,
    "training_files": training_files
}

generate_lsh_table(args)

Solution

Apache Beam pipelines are converted into a standard (e.g. proto) format before execution. As part of this, pipeline objects such as DoFns get serialized (pickled). If your DoFn holds instance variables that cannot be serialized, this process cannot proceed.
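The failure mode can be reproduced with plain `pickle`, no Beam required. A TensorFlow session handle behaves like any other unpicklable object; here a `threading.Lock` stands in for it (the class name `Holder` is just for illustration):

```python
import pickle
import threading

class Holder:
    """An object that stores an unpicklable resource as an attribute,
    analogous to a DoFn holding a live TF model/session."""
    def __init__(self):
        self.lock = threading.Lock()  # thread locks cannot be pickled

try:
    pickle.dumps(Holder())
except TypeError as e:
    # Same class of error as the TF_Operation one in the question
    print("pickling failed:", e)
```

Any attribute reachable from the object being pickled must itself be picklable, which is why storing a loaded Keras model on the pipeline object breaks submission.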

One way to work around this is to load/define such instance objects or modules during execution, instead of creating and storing them at pipeline submission time. This may require restructuring your pipeline.
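The pattern is to keep the DoFn itself free of unpicklable state and build the model lazily, the way Beam's `DoFn.setup()` hook runs on the worker. A minimal, Beam-free sketch of this idea (the names `LSHTableFn` and `UnpicklableModel` are hypothetical, and the real `setup()` would load the embedding model from disk or GCS):

```python
import pickle

class UnpicklableModel:
    """Stand-in for a TF model: refuses to be pickled."""
    def __reduce__(self):
        raise TypeError("can't pickle this object")

    def predict(self, x):
        return [v * 2 for v in x]

class LSHTableFn:
    """Beam-style DoFn sketch: nothing unpicklable is stored at
    construction time, so the object survives pipeline submission."""
    def __init__(self):
        self.model = None  # deferred until setup() runs on the worker

    def setup(self):
        # In a real pipeline: load the embedding model here.
        self.model = UnpicklableModel()

    def process(self, element):
        if self.model is None:  # defensive, if setup() was skipped
            self.setup()
        return self.model.predict(element)

fn = LSHTableFn()
serialized = pickle.dumps(fn)        # succeeds: no model inside yet
restored = pickle.loads(serialized)
restored.setup()                     # model built on the "worker"
print(restored.process([1, 2, 3]))   # [2, 4, 6]
```

Applied to the question, that means not passing `lsh_builder` (with its loaded `embedding_model`) into the pipeline as an argument, but constructing the model inside the DoFn's `setup()` on each worker.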