Python多重处理,变量的预初始化

问题描述

我正在尝试使用多处理模块并行化我的代码。我正在处理的代码分两个步骤工作。在第一步中,我初始化一个类,该类计算并保存多个变量,这些变量将在第二步中使用。在第二步中,程序使用先前初始化的变量执行计算。第一步的变量不做任何修改。第一步的计算时间并不重要,而在第二步中则很重要,因为它必须按顺序顺序被调用数百次。下面是代码结构和ist输出的最小构造示例。

import numpy as np
import time
from multiprocessing import Pool

class test:
    def __init__(self):
        self.r = np.ones(10000000)


    def f(self,init):
        summed = 0
        for i in range(0,init):
            summed = summed + i
        return summed


if __name__ == "__main__":
    # first step 
    func = test()
    
    
    # second step
    # sequential
    start_time = time.time()
    for i in [1000000,1000000,1000000]:
        func.f(i)
    print('Sequential: ',time.time()-start_time)

    
    # parallel
    start_time = time.time()
    pool = Pool(processes=None)
    result = pool.starmap(func.f,[[1000000],[1000000],[1000000]])
    print('Parallel: ',time.time()-start_time)

输出
顺序:0.2673146724700928
平行:1.5638213157653809

据我了解,由于必须将类测试的变量r传输到所有工作进程,因此多处理变得越来越慢。为了避免这种情况,我需要在启动f之前在每个worker上初始化类。多处理有可能吗?还有其他工具吗?

解决方法

只需创建函数

def my_function(value):
    func = Test()
    return func.f(value)

甚至

def my_function(value):
    return Test().f(value)

并使用它

result = pool.starmap(my_function,[[1000000],[1000000],[1000000]])

多处理不适用于lambda,因此您不能使用

pool.starmap(lambda value:Test().f(value),...)

functools.partial()可能不起作用,因此您不能使用它代替lambda


最小的工作示例

import numpy as np
import time
from multiprocessing import Pool

class Test:  # PEP8: `CamelCaseNames` for classes
    
    def __init__(self):
        self.r = np.ones(10000000)

    def f(self,init):
        summed = 0
        for i in range(init):
            summed = summed + i
        return summed

def my_function(value):
    func = Test()
    return func.f(value)

if __name__ == "__main__":

    data = [[1000000] for x in range(30)]

    # first step 
    func = Test()
    
    # second step
    # sequential
    start_time = time.time()
    for i in data:
        func.f(*i)   # `*i` like in starmap
    print('Sequential:',time.time()-start_time)
    
    # parallel 1
    start_time = time.time()
    pool = Pool(processes=None)
    result = pool.starmap(func.f,data)
    print('Parallel 1:',time.time()-start_time)
    
    # parallel 2
    start_time = time.time()
    pool = Pool(processes=None)
    result = pool.starmap(my_function,data)
    print('Parallel 2:',time.time()-start_time)
    

我的结果:

Sequential: 3.0593459606170654
Parallel 1: 5.2161490917205810
Parallel 2: 1.8350131511688232
,

我已经通过使用多处理模块中的Pipe函数解决了该问题。在第一步中,我现在可以初始化变量并设置多处理环境。然后,我使用Pipe函数来传输输入数据。

对于“ self.r = np.ones(100000000)”
平行管道:0.8008558750152588
平行2:18.51273012161255

对于“ self.r = np.ones(10000000)”
平行管道:0.71409010887146 平行2:1.4551067352294922

import numpy as np
import time
import multiprocessing as mp


class Test:  # PEP8: `CamelCaseNames` for classes
    def __init__(self):
        self.r = np.ones(100000000)

    def f(self,init):
        summed = 0
        for i in range(init):
            summed = summed + i
        return summed


def my_function(value):
    func = Test()
    return func.f(value)


class Connection:
    def __init__(self):
        self.process = {}
        self.parent = {}
        self.child = {}

    def add(self,hub,process,parent_conn,child_conn):
        self.process[hub] = process
        self.parent[hub] = parent_conn
        self.child[hub] = child_conn


def multi_run(child_conn,func,i):
    while 1:
        init = child_conn.recv()
        data = func.f(init)
        child_conn.send(data)


if __name__ == "__main__":
    N_processes = 4

    func = Test()
    conn = Connection()
    # First step
    for i in range(N_processes):
        parent_conn,child_conn = mp.Pipe()
        process = mp.Process(target=multi_run,args=(child_conn,i))
        conn.add(i,child_conn)
        process.start()

    start_time = time.time()
    data = [[1000000,x] for x in range(30)]
    # Second step
    for i,j in data:
        conn.parent[j % N_processes].send(i)
    for i,j in data:
        conn.parent[j % N_processes].recv()
    print('Parallel piped:',time.time()-start_time)

    data = [[1000000] for x in range(30)]
    # parallel 2
    start_time = time.time()
    pool = mp.Pool(processes=None)
    result = pool.starmap(my_function,time.time()-start_time)