如何在python中使用多处理而不复制大型只读字典

问题描述

我有一个查找表 LUT,它是一个非常大的字典 (24G)。 我有数百万个输入可以对其进行查询

我想将数百万个输入拆分为 32 个作业,并并行运行它们。 由于空间限制,我无法运行多个python脚本,因为这会导致内存过载。

我想使用 multiprocessing 模块只加载一次 LUT,然后让不同的进程查找它,同时将它作为全局变量共享,而不必复制它。

但是,当我查看 htop 时,似乎每个子进程都在重新创建 LUT?我提出此声明是因为根据 VIRTRESSHR。数字非常高。 但同时我没有看到 Mem 行中使用的额外内存,它从 11Gb 增加到 12.3G 并且只是悬停在那里。

所以我很困惑,是不是在每个子进程中重新创建 LUT ? 我应该如何确保我正在运行并行工作,而不在每个子进程中复制 LUT? 代码如下图所示。

enter image description here

在这个实验中我只使用了 1Gb 的 LUT,所以不要担心它不是 24Gb)

import os,sys,time,pprint,pdb,datetime
import threading,multiprocessing

## Print the process/thread details
def getDetails(idx):
    pid = os.getpid()
    threadName = threading.current_thread().name
    processName = multiprocessing.current_process().name
    print(f"{idx})\tpid={pid}\tprocessName={processName}\tthreadName={threadName} ")
    return pid,threadName,processName

def ComplexAlgorithm(value):
    # Instead of just lookup like this
    # the real algorithm is some complex algorithm that performs some search
    return value in LUT

## Querying the 24Gb LUT from my millions of lines of input
def PerformMatching(idx,NumberOfLines):
    pid,processName = getDetails(idx)
    NumberMatches = 0
    for _ in range(NumberOfLines):
        # I will actually read the contents from my file live,# but here just assume i generate random numbers
        value = random.randint(-100,100)
        if ComplexAlgorithm(value): NumberMatches += 1
    print(f"\t{idx}) | LUT={len(LUT)} | NumberMatches={NumberMatches} | done")

if __name__ == "__main__":

    ## Init
    num_processes = 9
    # this is just a pseudo-call to show you the structure of my LUT,the real one is larger
    LUT = (dict(i,set([i])) for i in range(1000))

    ## Store the multiple filenames
    listofLists = []
    for idx in range(num_processes):
        NumberOfLines = 10000
        listofLists.append( NumberOfLines )

    ## Init the processes
    ProcessList = []
    for processIndex in range(num_processes):
        ProcessList.append( 
            multiprocessing.Process(
                target=PerformMatching,args=(processIndex,listofLists[processIndex])
            )
        )
        ProcessList[processIndex].start()

    ## Wait until the process terminates.
    for processIndex in range(num_processes):
        ProcessList[processIndex].join()

    ## Done

解决方法

如果您想走使用 multiprocessing.Manager 的路线,您可以这样做。权衡是字典由对存在于不同地址空间中的实际字典的代理 的引用表示,因此每个字典引用都会导致等效的远程过程调用。换句话说,与“常规”字典相比,访问速度要慢得多。

在下面的演示程序中,我只为我的托管字典定义了几个方法,但是您可以定义任何您需要的方法。我还使用了多处理池,而不是显式启动单个进程;你可以考虑这样做。

from multiprocessing.managers import BaseManager,BaseProxy
from multiprocessing import Pool
from functools import partial

def worker(LUT,key):
    return LUT[key]


class MyDict:
    def __init__(self):
        """ initialize the dictionary """
        # the very large dictionary reduced for demo purposes:
        self._dict = {i: i for i in range(100)}

    def get(self,obj,default=None):
        """ delegates to underlying dict """
        return self._dict.get(obj,default)

    def __getitem__(self,obj):
        """ delegates to underlying dict """
        return self._dict[obj]

class MyDictManager(BaseManager):
    pass

class MyDictProxy(BaseProxy):
    _exposed_ = ('get','__getitem__')

    def get(self,*args,**kwargs):
        return self._callmethod('get',args,kwargs)

    def __getitem__(self,**kwargs):
        return self._callmethod('__getitem__',kwargs)


def main():
    MyDictManager.register('MyDict',MyDict,MyDictProxy)
    with MyDictManager() as manager:
        my_dict = manager.MyDict()
        pool = Pool()
        # pass proxy instead of actual LUT:
        results = pool.map(partial(worker,my_dict),range(100))
        print(sum(results))

if __name__ == '__main__':
    main()

打印:

4950

讨论

Python 带有一个内置的托管 dict 类,可通过 multiprocessing.Manager().dict() 获得。但是根据我之前的评论,每次访问都相对昂贵,使用这样的字典初始化如此大量的条目将是非常低效的。在我看来,创建我们自己的托管类会更便宜,它有一个底层的“常规”字典,可以在构建托管类时直接初始化,而不是通过代理引用。虽然 Python 附带的托管 dict 确实可以使用已经构建的字典进行实例化,从而避免了效率低下的问题,但我担心的是内存效率会受到影响,因为您会有两个字典实例,即“常规”字典和“托管”字典。