带有sklearn的Dask-ML随机森林导致连接关闭

问题描述

我正在尝试使用dask-ML训练模型。我的最终目标是对大于内存的数据集进行预测,因此我正在利用dask的ParallelPostFit包装器在相对较小的数据集(4 Gb)上训练模型,期望以后可以对较大的数据帧进行预测。我正在连接一个由50名工人组成的Yarn集群,将我的数据从实木复合地板加载到dask数据框中,创建了管道并进行了培训。训练是可行的,但是当我尝试对保留的测试集进行评估时,我遇到了问题。当我使用sklearn的LogisticRegression作为分类器时,训练和预测将成功运行。但是,当我使用带有100个估计量的sklearn随机森林时,训练步骤成功运行,但是根据预测,我得到了以下错误。我在预测计算步骤中注意到,在断开连接错误之前,我的本地计算机内存使用量开始激增。当我将RF估算器的数量减少到10时,预测步骤将成功运行。谁能帮助我了解发生了什么事?

我的代码(简明)

cluster = YarnCluster(environment=path_to_packed_conda_env,n_workers=50,worker_vcores=10,worker_env=worker_env,worker_restarts=10,scheduler_memory='10GiB',scheduler_vcores=5,worker_memory='20GiB')
cluster.adapt(minimum=50,maximum=100)

# connect client
client = Client(cluster)

# instantiate classifier
clf_rfc = RandomForestClassifier(n_estimators=100,n_jobs=5,criterion='gini',max_features='auto',min_samples_split = 50,class_weight='balanced',verbose=1,random_state=RANDOM_STATE)

# train/test split
X_train,X_test,y_train,y_test = train_test_split_dd(X,y,train_size = 0.7,random_state=RANDOM_STATE)

# build pipeline
pipe = ParallelPostFit(Pipeline(steps=[
    ('preprocessor',preprocessor),('classifier_',clone(clf_rfc))
]))

X_train = X_train.persist()
y_train = y_train.persist()

# Train
pipe.fit(X_train,y_train)

# Evaluate
X_test = X_test.persist()
y_test = y_test.persist()

print('computing ypred')
y_preds = pipe.predict(X_test).compute()

print('computing yprob')
y_probs = pipe.predict_proba(X_test).compute()

输出

computing ypred
distributed.batched - INFO - Batched Comm Closed: in <closed TCP>: ConnectionResetError: [Errno 104] Connection reset by peer
---------------------------------------------------------------------------
CancelledError                            Traceback (most recent call last)
<ipython-input-108-23f303f7584c> in <module>
      5 
      6 print('computing ypred')
----> 7 y_preds = [pipe.predict(X_test).compute() for pipe in pipes]
      8 
      9 print('computing yprob')

<ipython-input-108-23f303f7584c> in <listcomp>(.0)
      5 
      6 print('computing ypred')
----> 7 y_preds = [pipe.predict(X_test).compute() for pipe in pipes]
      8 
      9 print('computing yprob')

~/.conda/envs/boa/lib/python3.7/site-packages/dask/base.py in compute(self,**kwargs)
    164         dask.base.compute
    165         """
--> 166         (result,) = compute(self,traverse=False,**kwargs)
    167         return result
    168 

~/.conda/envs/boa/lib/python3.7/site-packages/dask/base.py in compute(*args,**kwargs)
    435     keys = [x.__dask_keys__() for x in collections]
    436     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 437     results = schedule(dsk,keys,**kwargs)
    438     return repack([f(r,*a) for r,(f,a) in zip(results,postcomputes)])
    439 

~/.conda/envs/boa/lib/python3.7/site-packages/distributed/client.py in get(self,dsk,restrictions,loose_restrictions,resources,sync,asynchronous,direct,retries,priority,fifo_timeout,actors,**kwargs)
   2593                     should_rejoin = False
   2594             try:
-> 2595                 results = self.gather(packed,asynchronous=asynchronous,direct=direct)
   2596             finally:
   2597                 for f in futures.values():

~/.conda/envs/boa/lib/python3.7/site-packages/distributed/client.py in gather(self,futures,errors,asynchronous)
   1891                 direct=direct,1892                 local_worker=local_worker,-> 1893                 asynchronous=asynchronous,1894             )
   1895 

~/.conda/envs/boa/lib/python3.7/site-packages/distributed/client.py in sync(self,func,callback_timeout,*args,**kwargs)
    778         else:
    779             return sync(
--> 780                 self.loop,callback_timeout=callback_timeout,**kwargs
    781             )
    782 

~/.conda/envs/boa/lib/python3.7/site-packages/distributed/utils.py in sync(loop,**kwargs)
    346     if error[0]:
    347         typ,exc,tb = error[0]
--> 348         raise exc.with_traceback(tb)
    349     else:
    350         return result[0]

~/.conda/envs/boa/lib/python3.7/site-packages/distributed/utils.py in f()
    330             if callback_timeout is not None:
    331                 future = asyncio.wait_for(future,callback_timeout)
--> 332             result[0] = yield future
    333         except Exception as exc:
    334             error[0] = sys.exc_info()

~/.conda/envs/boa/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

CancelledError: 
distributed.client - ERROR - Failed to reconnect to scheduler after 10.00 seconds,closing client
_GatheringFuture exception was never retrieved
future: <_GatheringFuture finished exception=CancelledError()>
concurrent.futures._base.CancelledError

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)