Scikit-learn / TensorFlow 交叉验证:无法序列化(pickle)_thread.RLock 对象

问题描述

在神经网络回归预测任务中,cross_val_predict 会抛出以下错误(完整报错和所用模型见下文):

TypeError: can't pickle _thread.RLock objects

所用输入数据的形状如下:

X_train.shape = (1200, 18, 15)
y_train.shape = (1200, 1)

神经网络的定义如下:

def twds_model(layer1=32, layer2=32, layer3=16, dropout_rate=0.5,
               optimizer='Adam', learning_rate=0.001, activation='relu',
               loss='mse', input_shape=None):
    """Build and compile the BiGRU -> Conv1D -> Dense regression network.

    Intended as the ``build_fn`` of a scikit-learn Keras wrapper such as
    ``KerasRegressor``: pass this function itself, not the model it returns.

    Args:
        layer1: Units of the GRU wrapped in the Bidirectional layer.
        layer2: Number of filters of the Conv1D layer named 'extractor'.
        layer3: Units of the hidden Dense layer.
        dropout_rate: Rate of the Dropout layer after the hidden Dense layer.
        optimizer: Optimizer passed to ``model.compile``.
        learning_rate: Not used by this function (see the note in the body).
        activation: Activation of the Conv1D and hidden Dense layers.
        loss: Loss passed to ``model.compile``.
        input_shape: ``(timesteps, features)`` of a single sample. When None,
            ``(X_train.shape[1], X_train.shape[2])`` from the global
            ``X_train`` is used, as in the original version.

    Returns:
        The compiled ``Sequential`` model, ending in a single-unit Dense layer.
    """
    if input_shape is None:
        # Resolved at call time (not as a default value) so the global
        # X_train is read when the model is built, exactly as before.
        input_shape = (X_train.shape[1], X_train.shape[2])
    # NOTE(review): learning_rate is never forwarded; with a string optimizer
    # name such as the default 'Adam', the optimizer's own default rate is
    # used. Pass a configured optimizer instance if the rate should apply.
    model = Sequential()
    model.add(Bidirectional(GRU(layer1, return_sequences=True),
                            input_shape=input_shape))
    model.add(AveragePooling1D(2))
    model.add(Conv1D(layer2, 3, activation=activation, padding='same',
                     name='extractor'))
    model.add(Flatten())
    model.add(Dense(layer3, activation=activation))
    model.add(Dropout(dropout_rate))
    model.add(Dense(1))
    model.compile(optimizer=optimizer, loss=loss)
    return model

# Build one instance only to inspect the architecture, under a name that does
# NOT shadow the builder function. Rebinding `twds_model` to the compiled model
# made `KerasRegressor(build_fn=twds_model, ...)` receive a model instance;
# sklearn's clone() then deep-copies that non-estimator parameter, which fails
# with "TypeError: can't pickle _thread.RLock objects".
# Later code that called `twds_model.fit(...)` should use `twds_net.fit(...)`.
twds_net = twds_model()
# summary() prints the table itself and returns None; wrapping it in print()
# only added a stray "None" line to the output.
twds_net.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
bidirectional_4 (Bidirection (None, 18, 64)            9216      
_________________________________________________________________
average_pooling1d_1 (Average (None,9,64)             0         
_________________________________________________________________
extractor (Conv1D)           (None, 9, 32)             6176      
_________________________________________________________________
flatten_1 (Flatten)          (None,288)               0         
_________________________________________________________________
dense_3 (Dense)              (None,16)                4624      
_________________________________________________________________
dropout_4 (Dropout)          (None,16)                0         
_________________________________________________________________
dense_4 (Dense)              (None,1)                 17        
=================================================================
Total params: 20,033
Trainable params: 20,033
Non-trainable params: 0
_________________________________________________________________
None

# scikit-learn-compatible wrapper around the network builder.
# `build_fn` must be the builder function itself, not a model instance:
# sklearn's clone() deep-copies any parameter without get_params().
model_twds = KerasRegressor(
    build_fn=twds_model,
    batch_size=144,
    epochs=6,  # the original note "12" was presumably an alternative epoch count
)

完整错误:

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-29-b2465d2e01ce> in <module>
      4                n_jobs=1,5                cv=4,----> 6                verbose=2)

~\Anaconda3\envs\Tensorflow\lib\site-packages\sklearn\utils\validation.py in inner_f(*args,**kwargs)
     70                           FutureWarning)
     71         kwargs.update({k: arg for k,arg in zip(sig.parameters,args)})
---> 72         return f(**kwargs)
     73     return inner_f
     74 

~\Anaconda3\envs\Tensorflow\lib\site-packages\sklearn\model_selection\_validation.py in cross_val_predict(estimator,X,y,groups,cv,n_jobs,verbose,fit_params,pre_dispatch,method)
    771     prediction_blocks = parallel(delayed(_fit_and_predict)(
    772         clone(estimator),train,test,method)
--> 773         for train,test in cv.split(X,groups))
    774 
    775     # Concatenate the predictions

~\Anaconda3\envs\Tensorflow\lib\site-packages\joblib\parallel.py in __call__(self,iterable)
    919             # remaining jobs.
    920             self._iterating = False
--> 921             if self.dispatch_one_batch(iterator):
    922                 self._iterating = self._original_iterator is not None
    923 

~\Anaconda3\envs\Tensorflow\lib\site-packages\joblib\parallel.py in dispatch_one_batch(self,iterator)
    752             tasks = BatchedCalls(itertools.islice(iterator,batch_size),753                                  self._backend.get_nested_backend(),--> 754                                  self._pickle_cache)
    755             if len(tasks) == 0:
    756                 # No more tasks available in the iterator: tell caller to stop.

~\Anaconda3\envs\Tensorflow\lib\site-packages\joblib\parallel.py in __init__(self,iterator_slice,backend_and_jobs,pickle_cache)
    208 
    209     def __init__(self,pickle_cache=None):
--> 210         self.items = list(iterator_slice)
    211         self._size = len(self.items)
    212         if isinstance(backend_and_jobs,tuple):

~\Anaconda3\envs\Tensorflow\lib\site-packages\sklearn\model_selection\_validation.py in <genexpr>(.0)
    771     prediction_blocks = parallel(delayed(_fit_and_predict)(
    772         clone(estimator),groups))
    774 
    775     # Concatenate the predictions

~\Anaconda3\envs\Tensorflow\lib\site-packages\sklearn\utils\validation.py in inner_f(*args,args)})
---> 72         return f(**kwargs)
     73     return inner_f
     74 

~\Anaconda3\envs\Tensorflow\lib\site-packages\sklearn\base.py in clone(estimator,safe)
     85     new_object_params = estimator.get_params(deep=False)
     86     for name,param in new_object_params.items():
---> 87         new_object_params[name] = clone(param,safe=False)
     88     new_object = klass(**new_object_params)
     89     params_set = new_object.get_params(deep=False)

~\Anaconda3\envs\Tensorflow\lib\site-packages\sklearn\utils\validation.py in inner_f(*args,safe)
     69     elif not hasattr(estimator,'get_params') or isinstance(estimator,type):
     70         if not safe:
---> 71             return copy.deepcopy(estimator)
     72         else:
     73             if isinstance(estimator,type):

~\Anaconda3\envs\Tensorflow\lib\copy.py in deepcopy(x,memo,_nil)
    178                     y = x
    179                 else:
--> 180                     y = _reconstruct(x,*rv)
    181 
    182     # If is its own copy,don't memoize.

~\Anaconda3\envs\Tensorflow\lib\copy.py in _reconstruct(x,func,args,state,listiter,dictiter,deepcopy)
    278     if state is not None:
    279         if deep:
--> 280             state = deepcopy(state,memo)
    281         if hasattr(y,'__setstate__'):
    282             y.__setstate__(state)

~\Anaconda3\envs\Tensorflow\lib\copy.py in deepcopy(x,_nil)
    148     copier = _deepcopy_dispatch.get(cls)
    149     if copier:
--> 150         y = copier(x,memo)
    151     else:
    152         try:

~\Anaconda3\envs\Tensorflow\lib\copy.py in _deepcopy_dict(x,deepcopy)
    238     memo[id(x)] = y
    239     for key,value in x.items():
--> 240         y[deepcopy(key,memo)] = deepcopy(value,memo)
    241     return y
    242 d[dict] = _deepcopy_dict

~\Anaconda3\envs\Tensorflow\lib\copy.py in deepcopy(x,memo)
    151     else:
    152         try:

~\Anaconda3\envs\Tensorflow\lib\copy.py in _deepcopy_list(x,deepcopy)
    213     append = y.append
    214     for a in x:
--> 215         append(deepcopy(a,memo))
    216     return y
    217 d[list] = _deepcopy_list

~\Anaconda3\envs\Tensorflow\lib\copy.py in deepcopy(x,_nil)
    167                     reductor = getattr(x,"__reduce_ex__",None)
    168                     if reductor:
--> 169                         rv = reductor(4)
    170                     else:
    171                         reductor = getattr(x,"__reduce__",None)

TypeError: can't pickle _thread.RLock objects

直接用 TensorFlow 的 .fit 训练时,会出现以下错误,可能与第一个问题有关。

Train on 10420 samples,validate on 1697 samples
Epoch 1/8

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-30-3f5256ff03ec> in <module>
----> 1 Test_tdws=twds_model.fit(X_train,y_train,epochs=8,verbose=2,validation_split=(0.14),shuffle=False) #callbacks=[tensorboard])

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self,x,batch_size,epochs,callbacks,validation_split,validation_data,shuffle,class_weight,sample_weight,initial_epoch,steps_per_epoch,validation_steps,max_queue_size,workers,use_multiprocessing,**kwargs)
    878           initial_epoch=initial_epoch,879           steps_per_epoch=steps_per_epoch,--> 880           validation_steps=validation_steps)
    881 
    882   def evaluate(self,~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training_arrays.py in model_iteration(model,inputs,targets,sample_weights,val_inputs,val_targets,val_sample_weights,mode,validation_in_fit,**kwargs)
    327 
    328         # Get outputs.
--> 329         batch_outs = f(ins_batch)
    330         if not isinstance(batch_outs,list):
    331           batch_outs = [batch_outs]

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\backend.py in __call__(self,inputs)
   3074 
   3075     fetched = self._callable_fn(*array_vals,-> 3076                                 run_metadata=self.run_metadata)
   3077     self._call_fetch_callbacks(fetched[-len(self._fetches):])
   3078     return nest.pack_sequence_as(self._outputs_structure,~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\client\session.py in __call__(self,*args,**kwargs)
   1437           ret = tf_session.TF_SessionRunCallable(
   1438               self._session._session,self._handle,status,-> 1439               run_metadata_ptr)
   1440         if run_metadata:
   1441           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\framework\errors_impl.py in __exit__(self,type_arg,value_arg,traceback_arg)
    526             None,None,527             compat.as_text(c_api.TF_Message(self.status.status)),--> 528             c_api.TF_GetCode(self.status.status))
    529     # Delete the underlying status object from memory otherwise it stays alive
    530     # as there is a reference to status from this from the traceback due to

InvalidArgumentError: Incompatible shapes: [144,1] vs. [144,1]
     [[{{node loss_2/dense_4_loss/sub}}]]
     [[{{node loss_2/mul}}]]

使用的版本是:

Package                Version
---------------------- --------------------
-                      nsorflow-gpu
-ensorflow-gpu         1.13.1
-rotobuf               3.11.3
-umpy                  1.18.1
absl-py                0.9.0
antlr4-python3-runtime 4.8
asn1crypto             1.3.0
astor                  0.7.1
astropy                3.2.1
astunparse             1.6.3
attrs                  19.3.0
audioread              2.1.8
autopep8               1.5.3
backcall               0.1.0
beautifulsoup4         4.9.0
bezier                 0.8.0
bkcharts               0.2
bleach                 3.1.4
blis                   0.2.4
bokeh                  1.1.0
boto3                  1.9.253
botocore               1.12.253
Bottleneck             1.3.2
cachetools             4.1.0
certifi                2020.4.5.1
cffi                   1.14.0
chardet                3.0.4
click                  6.7
cloudpickle            0.5.3
cmdstanpy              0.4.0
color                  0.1
colorama               0.4.3
colorcet               0.9.1
convertdate            2.2.1
copulas                0.2.5
cryptography           2.8
ctgan                  0.2.1
cycler                 0.10.0
cymem                  2.0.2
Cython                 0.29.17
dash                   0.26.0
dash-core-components   0.27.2
dash-html-components   0.11.0
dash-renderer          0.13.2
dask                   0.18.1
dataclasses            0.6
datashader             0.7.0
datashape              0.5.2
datawig                0.1.10
deap                   1.3.0
decorator              4.4.2
defusedxml             0.6.0
deltapy                0.1.1
dill                   0.2.9
distributed            1.22.1
docutils               0.14
entrypoints            0.3
ephem                  3.7.7.1
et-xmlfile             1.0.1
exrex                  0.10.5
Faker                  4.0.3
fastai                 1.0.60
fastprogress           0.2.2
fbprophet              0.6
fire                   0.3.1
Flask                  1.0.2
Flask-Compress         1.4.0
future                 0.17.1
gast                   0.3.3
geojson                2.4.1
geomet                 0.2.0.post2
google-auth            1.14.0
google-auth-oauthlib   0.4.1
google-pasta           0.2.0
gplearn                0.4.1
graphviz               0.13.2
grpcio                 1.29.0
h5py                   2.10.0
HeapDict               1.0.0
holidays               0.10.2
holoviews              1.12.1
html2text              2018.1.9
hyperas                0.4.1
hyperopt               0.1.2
idna                   2.6
imageio                2.5.0
imbalanced-learn       0.3.3
imblearn               0.0
importlib-metadata     1.5.0
impyute                0.0.8
ipykernel              5.1.4
ipython                7.13.0
ipython-genutils       0.2.0
ipywidgets             7.5.1
itsdangerous           0.24
jdcal                  1.4
jedi                   0.16.0
Jinja2                 2.11.1
jmespath               0.9.5
joblib                 0.13.2
jsonschema             3.2.0
jupyter                1.0.0
jupyter-client         6.1.2
jupyter-console        6.0.0
jupyter-core           4.6.3
Keras                  2.2.5
Keras-Applications     1.0.8
Keras-Preprocessing    1.1.2
keras-rectified-adam   0.17.0
kiwisolver             1.2.0
korean-lunar-calendar  0.2.1
librosa                0.7.2
llvmlite               0.32.1
lml                    0.0.1
locket                 0.2.0
LunarCalendar          0.0.9
Markdown               2.6.11
MarkupSafe             1.1.1
matplotlib             3.2.1
missingpy              0.2.0
mistune                0.8.4
mkl-fft                1.0.15
mkl-random             1.1.0
mkl-service            2.3.0
mock                   4.0.2
msgpack                0.5.6
multipledispatch       0.6.0
murmurhash             1.0.2
mxnet                  1.4.1
nb-conda               2.2.1
nb-conda-kernels       2.2.3
nbconvert              5.6.1
nbformat               5.0.4
nbstripout             0.3.7
networkx               2.1
notebook               6.0.3
numba                  0.49.1
numexpr                2.7.1
numpy                  1.19.0
oauthlib               3.1.0
olefile                0.46
opencv-python          4.2.0.34
openpyxl               2.5.5
opt-einsum             3.2.1
packaging              20.3
pandas                 1.0.3
pandasvault            0.0.3
pandocfilters          1.4.2
param                  1.9.0
parso                  0.6.2
partd                  0.3.8
patsy                  0.5.1
pbr                    5.1.3
pickleshare            0.7.5
Pillow                 7.0.0
pip                    20.0.2
plac                   0.9.6
plotly                 4.7.1
plotly-express         0.4.1
preshed                2.0.1
prometheus-client      0.7.1
prompt-toolkit         3.0.4
protobuf               3.11.3
psutil                 5.4.7
py                     1.8.0
pyasn1                 0.4.8
pyasn1-modules         0.2.8
pycodestyle            2.6.0
pycparser              2.20
pyct                   0.4.5
pyensae                1.3.839
pyexcel                0.5.8
pyexcel-io             0.5.7
Pygments               2.6.1
pykalman               0.9.5
PyMeeus                0.3.7
pymongo                3.8.0
pyOpenSSL              19.1.0
pyparsing              2.4.7
pypi                   2.1
pyquickhelper          1.9.3418
pyrsistent             0.16.0
PySocks                1.7.1
pystan                 2.19.1.1
python-dateutil        2.8.1
pytz                   2019.3
pyviz-comms            0.7.2
PyWavelets             0.5.2
pywin32                227
pywinpty               0.5.7
PyYAML                 5.3.1
pyzmq                  18.1.1
qtconsole              4.4.4
rdt                    0.2.1
RegscorePy             1.1
requests               2.23.0
requests-oauthlib      1.3.0
resampy                0.2.2
retrying               1.3.3
rsa                    4.0
s3transfer             0.2.1
scikit-image           0.15.0
scikit-learn           0.23.2
scipy                  1.4.1
sdv                    0.3.2
seaborn                0.9.0
seasonal               0.3.1
Send2Trash             1.5.0
sentinelsat            0.12.2
setuptools             46.3.0
setuptools-git         1.2
six                    1.14.0
sklearn                0.0
sortedcontainers       2.0.4
SoundFile              0.10.3.post1
soupsieve              2.0
spacy                  2.1.8
srsly                  0.1.0
statsmodels            0.9.0
stopit                 1.1.2
sugartensor            1.0.0.2
ta                     0.5.25
tb-nightly             1.14.0a20190603
tblib                  1.3.2
tensorboard            1.13.1
tensorboard-plugin-wit 1.6.0.post3
tensorflow-estimator   1.13.0
tensorflow-gpu         1.13.1
termcolor              1.1.0
terminado              0.8.3
testpath               0.4.4
text-unidecode         1.3
texttable              1.4.0
tf-estimator-nightly   1.14.0.dev2019060501
Theano                 1.0.4
thinc                  7.0.8
threadpoolctl          2.1.0
toml                   0.10.1
toolz                  0.10.0
torch                  1.4.0
torchvision            0.5.0
tornado                6.0.4
TPOT                   0.10.2
tqdm                   4.45.0
traitlets              4.3.3
transforms3d           0.3.1
tsaug                  0.2.1
typeguard              2.7.1
typing                 3.6.6
update-checker         0.16
urllib3                1.22
utm                    0.4.2
wasabi                 0.2.2
wcwidth                0.1.9
webencodings           0.5.1
Werkzeug               1.0.1
wheel                  0.34.2
widgetsnbextension     3.5.1
win-inet-pton          1.1.0
wincertstore           0.2
wrapt                  1.11.2
xarray                 0.10.8
xlrd                   1.1.0
yahoo-historical       0.3.2
zict                   0.1.3
zipp                   2.2.0

任何能帮助代码正常运行的提示都非常感谢 ;-)!

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...