问题描述
在神经网络回归预测任务中,`cross_val_predict` 会引发以下错误(完整错误信息见下文,所用模型也在下面):
TypeError: can't pickle _thread.RLock objects
当使用以下形状的输入数据时:
X_train.shape = (1200, 18, 15)
y_train.shape = (1200, 1)
以及如下神经网络(NN)定义:
def twds_model(layer1=32, layer2=32, layer3=16, dropout_rate=0.5,
               optimizer='Adam', learning_rate=0.001, activation='relu',
               loss='mse', input_shape=None):
    """Build and compile a BiGRU -> AveragePooling1D -> Conv1D -> Dense regressor.

    Meant to be handed to ``KerasRegressor(build_fn=...)`` as the function
    itself; do not pass the result of calling it.

    Args:
        layer1: Units of the GRU wrapped in ``Bidirectional`` (full sequences
            are returned so the following 1-D layers see a time axis).
        layer2: Number of filters of the Conv1D layer named ``'extractor'``
            (kernel size 3, ``'same'`` padding).
        layer3: Units of the hidden Dense layer.
        dropout_rate: Dropout rate applied after the hidden Dense layer.
        optimizer: Optimizer passed as-is to ``model.compile``.
        learning_rate: NOTE(review): currently NOT applied. The optimizer is
            passed by name, so its default learning rate is used. Kept so the
            signature (and any KerasRegressor / search params) stays valid.
        activation: Activation of the Conv1D and hidden Dense layers.
        loss: Loss passed to ``model.compile``.
        input_shape: ``(timesteps, features)`` of a single sample. When None,
            falls back to ``(X_train.shape[1], X_train.shape[2])`` read from
            the global ``X_train`` at call time (the original behaviour).

    Returns:
        The compiled ``Sequential`` model ending in a single-unit Dense layer.
    """
    if input_shape is None:
        # Backward-compatible default: derive the shape from the global data.
        input_shape = (X_train.shape[1], X_train.shape[2])

    model = Sequential()
    # return_sequences=True keeps the time axis for the pooling/conv layers.
    model.add(Bidirectional(GRU(layer1, return_sequences=True),
                            input_shape=input_shape))
    # pool_size=2 (stride defaults to the pool size) halves the time axis.
    model.add(AveragePooling1D(2))
    model.add(Conv1D(layer2, 3, activation=activation, padding='same',
                     name='extractor'))
    model.add(Flatten())
    model.add(Dense(layer3, activation=activation))
    model.add(Dropout(dropout_rate))
    model.add(Dense(1))
    model.compile(optimizer=optimizer, loss=loss)
    return model
# Build one instance only to inspect the architecture, and bind it to a NEW
# name. Rebinding `twds_model` here would replace the builder function with a
# Model instance. KerasRegressor(build_fn=twds_model) would then carry that
# Model, and sklearn's clone() (used by cross_val_predict) deep-copies it,
# failing with "TypeError: can't pickle _thread.RLock objects".
twds_net = twds_model()
# summary() prints the table itself and returns None; wrapping it in print()
# only adds a stray "None" line.
twds_net.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
bidirectional_4 (Bidirection (None,18,64) 9216
_________________________________________________________________
average_pooling1d_1 (Average (None,9,64) 0
_________________________________________________________________
extractor (Conv1D) (None,9,32) 6176
_________________________________________________________________
flatten_1 (Flatten) (None,288) 0
_________________________________________________________________
dense_3 (Dense) (None,16) 4624
_________________________________________________________________
dropout_4 (Dropout) (None,16) 0
_________________________________________________________________
dense_4 (Dense) (None,1) 17
=================================================================
Total params: 20,033
Trainable params: 20,033
Non-trainable params: 0
_________________________________________________________________
None
以及如下 KerasRegressor 封装:
# build_fn must be the builder FUNCTION twds_model, not an already-built Keras
# Model. cross_val_predict calls clone(estimator), and sklearn's clone()
# deep-copies parameters that lack get_params (such as build_fn) via
# copy.deepcopy. A Keras Model cannot be deep-copied this way and raises
# "TypeError: can't pickle _thread.RLock objects".
model_twds=KerasRegressor(build_fn=twds_model,batch_size=144,epochs=6)#12
完整错误:
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-29-b2465d2e01ce> in <module>
4 n_jobs=1,5 cv=4,----> 6 verbose=2)
~\Anaconda3\envs\Tensorflow\lib\site-packages\sklearn\utils\validation.py in inner_f(*args,**kwargs)
70 FutureWarning)
71 kwargs.update({k: arg for k,arg in zip(sig.parameters,args)})
---> 72 return f(**kwargs)
73 return inner_f
74
~\Anaconda3\envs\Tensorflow\lib\site-packages\sklearn\model_selection\_validation.py in cross_val_predict(estimator,X,y,groups,cv,n_jobs,verbose,fit_params,pre_dispatch,method)
771 prediction_blocks = parallel(delayed(_fit_and_predict)(
772 clone(estimator),train,test,method)
--> 773 for train,test in cv.split(X,groups))
774
775 # Concatenate the predictions
~\Anaconda3\envs\Tensorflow\lib\site-packages\joblib\parallel.py in __call__(self,iterable)
919 # remaining jobs.
920 self._iterating = False
--> 921 if self.dispatch_one_batch(iterator):
922 self._iterating = self._original_iterator is not None
923
~\Anaconda3\envs\Tensorflow\lib\site-packages\joblib\parallel.py in dispatch_one_batch(self,iterator)
752 tasks = BatchedCalls(itertools.islice(iterator,batch_size),753 self._backend.get_nested_backend(),--> 754 self._pickle_cache)
755 if len(tasks) == 0:
756 # No more tasks available in the iterator: tell caller to stop.
~\Anaconda3\envs\Tensorflow\lib\site-packages\joblib\parallel.py in __init__(self,iterator_slice,backend_and_jobs,pickle_cache)
208
209 def __init__(self,pickle_cache=None):
--> 210 self.items = list(iterator_slice)
211 self._size = len(self.items)
212 if isinstance(backend_and_jobs,tuple):
~\Anaconda3\envs\Tensorflow\lib\site-packages\sklearn\model_selection\_validation.py in <genexpr>(.0)
771 prediction_blocks = parallel(delayed(_fit_and_predict)(
772 clone(estimator),groups))
774
775 # Concatenate the predictions
~\Anaconda3\envs\Tensorflow\lib\site-packages\sklearn\utils\validation.py in inner_f(*args,args)})
---> 72 return f(**kwargs)
73 return inner_f
74
~\Anaconda3\envs\Tensorflow\lib\site-packages\sklearn\base.py in clone(estimator,safe)
85 new_object_params = estimator.get_params(deep=False)
86 for name,param in new_object_params.items():
---> 87 new_object_params[name] = clone(param,safe=False)
88 new_object = klass(**new_object_params)
89 params_set = new_object.get_params(deep=False)
~\Anaconda3\envs\Tensorflow\lib\site-packages\sklearn\utils\validation.py in inner_f(*args,safe)
69 elif not hasattr(estimator,'get_params') or isinstance(estimator,type):
70 if not safe:
---> 71 return copy.deepcopy(estimator)
72 else:
73 if isinstance(estimator,type):
~\Anaconda3\envs\Tensorflow\lib\copy.py in deepcopy(x,memo,_nil)
178 y = x
179 else:
--> 180 y = _reconstruct(x,*rv)
181
182 # If is its own copy,don't memoize.
~\Anaconda3\envs\Tensorflow\lib\copy.py in _reconstruct(x,func,args,state,listiter,dictiter,deepcopy)
278 if state is not None:
279 if deep:
--> 280 state = deepcopy(state,memo)
281 if hasattr(y,'__setstate__'):
282 y.__setstate__(state)
~\Anaconda3\envs\Tensorflow\lib\copy.py in deepcopy(x,_nil)
148 copier = _deepcopy_dispatch.get(cls)
149 if copier:
--> 150 y = copier(x,memo)
151 else:
152 try:
~\Anaconda3\envs\Tensorflow\lib\copy.py in _deepcopy_dict(x,deepcopy)
238 memo[id(x)] = y
239 for key,value in x.items():
--> 240 y[deepcopy(key,memo)] = deepcopy(value,memo)
241 return y
242 d[dict] = _deepcopy_dict
~\Anaconda3\envs\Tensorflow\lib\copy.py in deepcopy(x,memo)
151 else:
152 try:
~\Anaconda3\envs\Tensorflow\lib\copy.py in _deepcopy_list(x,deepcopy)
213 append = y.append
214 for a in x:
--> 215 append(deepcopy(a,memo))
216 return y
217 d[list] = _deepcopy_list
~\Anaconda3\envs\Tensorflow\lib\copy.py in deepcopy(x,_nil)
167 reductor = getattr(x,"__reduce_ex__",None)
168 if reductor:
--> 169 rv = reductor(4)
170 else:
171 reductor = getattr(x,"__reduce__",None)
TypeError: can't pickle _thread.RLock objects
直接使用普通的 TensorFlow `.fit()` 训练时,则弹出以下错误。也许这与第一个问题有关。
Train on 10420 samples,validate on 1697 samples
Epoch 1/8
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-30-3f5256ff03ec> in <module>
----> 1 Test_tdws=twds_model.fit(X_train,y_train,epochs=8,verbose=2,validation_split=(0.14),shuffle=False) #callbacks=[tensorboard])
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self,x,batch_size,epochs,callbacks,validation_split,validation_data,shuffle,class_weight,sample_weight,initial_epoch,steps_per_epoch,validation_steps,max_queue_size,workers,use_multiprocessing,**kwargs)
878 initial_epoch=initial_epoch,879 steps_per_epoch=steps_per_epoch,--> 880 validation_steps=validation_steps)
881
882 def evaluate(self,~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training_arrays.py in model_iteration(model,inputs,targets,sample_weights,val_inputs,val_targets,val_sample_weights,mode,validation_in_fit,**kwargs)
327
328 # Get outputs.
--> 329 batch_outs = f(ins_batch)
330 if not isinstance(batch_outs,list):
331 batch_outs = [batch_outs]
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\backend.py in __call__(self,inputs)
3074
3075 fetched = self._callable_fn(*array_vals,-> 3076 run_metadata=self.run_metadata)
3077 self._call_fetch_callbacks(fetched[-len(self._fetches):])
3078 return nest.pack_sequence_as(self._outputs_structure,~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\client\session.py in __call__(self,*args,**kwargs)
1437 ret = tf_session.TF_SessionRunCallable(
1438 self._session._session,self._handle,status,-> 1439 run_metadata_ptr)
1440 if run_metadata:
1441 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\framework\errors_impl.py in __exit__(self,type_arg,value_arg,traceback_arg)
526 None,None,527 compat.as_text(c_api.TF_Message(self.status.status)),--> 528 c_api.TF_GetCode(self.status.status))
529 # Delete the underlying status object from memory otherwise it stays alive
530 # as there is a reference to status from this from the traceback due to
InvalidArgumentError: Incompatible shapes: [144,1] vs. [144,1]
[[{{node loss_2/dense_4_loss/sub}}]]
[[{{node loss_2/mul}}]]
使用的版本是:
Package Version
---------------------- --------------------
- nsorflow-gpu
-ensorflow-gpu 1.13.1
-rotobuf 3.11.3
-umpy 1.18.1
absl-py 0.9.0
antlr4-python3-runtime 4.8
asn1crypto 1.3.0
astor 0.7.1
astropy 3.2.1
astunparse 1.6.3
attrs 19.3.0
audioread 2.1.8
autopep8 1.5.3
backcall 0.1.0
beautifulsoup4 4.9.0
bezier 0.8.0
bkcharts 0.2
bleach 3.1.4
blis 0.2.4
bokeh 1.1.0
boto3 1.9.253
botocore 1.12.253
Bottleneck 1.3.2
cachetools 4.1.0
certifi 2020.4.5.1
cffi 1.14.0
chardet 3.0.4
click 6.7
cloudpickle 0.5.3
cmdstanpy 0.4.0
color 0.1
colorama 0.4.3
colorcet 0.9.1
convertdate 2.2.1
copulas 0.2.5
cryptography 2.8
ctgan 0.2.1
cycler 0.10.0
cymem 2.0.2
Cython 0.29.17
dash 0.26.0
dash-core-components 0.27.2
dash-html-components 0.11.0
dash-renderer 0.13.2
dask 0.18.1
dataclasses 0.6
datashader 0.7.0
datashape 0.5.2
datawig 0.1.10
deap 1.3.0
decorator 4.4.2
defusedxml 0.6.0
deltapy 0.1.1
dill 0.2.9
distributed 1.22.1
docutils 0.14
entrypoints 0.3
ephem 3.7.7.1
et-xmlfile 1.0.1
exrex 0.10.5
Faker 4.0.3
fastai 1.0.60
fastprogress 0.2.2
fbprophet 0.6
fire 0.3.1
Flask 1.0.2
Flask-Compress 1.4.0
future 0.17.1
gast 0.3.3
geojson 2.4.1
geomet 0.2.0.post2
google-auth 1.14.0
google-auth-oauthlib 0.4.1
google-pasta 0.2.0
gplearn 0.4.1
graphviz 0.13.2
grpcio 1.29.0
h5py 2.10.0
HeapDict 1.0.0
holidays 0.10.2
holoviews 1.12.1
html2text 2018.1.9
hyperas 0.4.1
hyperopt 0.1.2
idna 2.6
imageio 2.5.0
imbalanced-learn 0.3.3
imblearn 0.0
importlib-metadata 1.5.0
impyute 0.0.8
ipykernel 5.1.4
ipython 7.13.0
ipython-genutils 0.2.0
ipywidgets 7.5.1
itsdangerous 0.24
jdcal 1.4
jedi 0.16.0
Jinja2 2.11.1
jmespath 0.9.5
joblib 0.13.2
jsonschema 3.2.0
jupyter 1.0.0
jupyter-client 6.1.2
jupyter-console 6.0.0
jupyter-core 4.6.3
Keras 2.2.5
Keras-Applications 1.0.8
Keras-Preprocessing 1.1.2
keras-rectified-adam 0.17.0
kiwisolver 1.2.0
korean-lunar-calendar 0.2.1
librosa 0.7.2
llvmlite 0.32.1
lml 0.0.1
locket 0.2.0
LunarCalendar 0.0.9
Markdown 2.6.11
MarkupSafe 1.1.1
matplotlib 3.2.1
missingpy 0.2.0
mistune 0.8.4
mkl-fft 1.0.15
mkl-random 1.1.0
mkl-service 2.3.0
mock 4.0.2
msgpack 0.5.6
multipledispatch 0.6.0
murmurhash 1.0.2
mxnet 1.4.1
nb-conda 2.2.1
nb-conda-kernels 2.2.3
nbconvert 5.6.1
nbformat 5.0.4
nbstripout 0.3.7
networkx 2.1
notebook 6.0.3
numba 0.49.1
numexpr 2.7.1
numpy 1.19.0
oauthlib 3.1.0
olefile 0.46
opencv-python 4.2.0.34
openpyxl 2.5.5
opt-einsum 3.2.1
packaging 20.3
pandas 1.0.3
pandasvault 0.0.3
pandocfilters 1.4.2
param 1.9.0
parso 0.6.2
partd 0.3.8
patsy 0.5.1
pbr 5.1.3
pickleshare 0.7.5
Pillow 7.0.0
pip 20.0.2
plac 0.9.6
plotly 4.7.1
plotly-express 0.4.1
preshed 2.0.1
prometheus-client 0.7.1
prompt-toolkit 3.0.4
protobuf 3.11.3
psutil 5.4.7
py 1.8.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pycodestyle 2.6.0
pycparser 2.20
pyct 0.4.5
pyensae 1.3.839
pyexcel 0.5.8
pyexcel-io 0.5.7
Pygments 2.6.1
pykalman 0.9.5
PyMeeus 0.3.7
pymongo 3.8.0
pyOpenSSL 19.1.0
pyparsing 2.4.7
pypi 2.1
pyquickhelper 1.9.3418
pyrsistent 0.16.0
PySocks 1.7.1
pystan 2.19.1.1
python-dateutil 2.8.1
pytz 2019.3
pyviz-comms 0.7.2
PyWavelets 0.5.2
pywin32 227
pywinpty 0.5.7
PyYAML 5.3.1
pyzmq 18.1.1
qtconsole 4.4.4
rdt 0.2.1
RegscorePy 1.1
requests 2.23.0
requests-oauthlib 1.3.0
resampy 0.2.2
retrying 1.3.3
rsa 4.0
s3transfer 0.2.1
scikit-image 0.15.0
scikit-learn 0.23.2
scipy 1.4.1
sdv 0.3.2
seaborn 0.9.0
seasonal 0.3.1
Send2Trash 1.5.0
sentinelsat 0.12.2
setuptools 46.3.0
setuptools-git 1.2
six 1.14.0
sklearn 0.0
sortedcontainers 2.0.4
SoundFile 0.10.3.post1
soupsieve 2.0
spacy 2.1.8
srsly 0.1.0
statsmodels 0.9.0
stopit 1.1.2
sugartensor 1.0.0.2
ta 0.5.25
tb-nightly 1.14.0a20190603
tblib 1.3.2
tensorboard 1.13.1
tensorboard-plugin-wit 1.6.0.post3
tensorflow-estimator 1.13.0
tensorflow-gpu 1.13.1
termcolor 1.1.0
terminado 0.8.3
testpath 0.4.4
text-unidecode 1.3
texttable 1.4.0
tf-estimator-nightly 1.14.0.dev2019060501
Theano 1.0.4
thinc 7.0.8
threadpoolctl 2.1.0
toml 0.10.1
toolz 0.10.0
torch 1.4.0
torchvision 0.5.0
tornado 6.0.4
TPOT 0.10.2
tqdm 4.45.0
traitlets 4.3.3
transforms3d 0.3.1
tsaug 0.2.1
typeguard 2.7.1
typing 3.6.6
update-checker 0.16
urllib3 1.22
utm 0.4.2
wasabi 0.2.2
wcwidth 0.1.9
webencodings 0.5.1
Werkzeug 1.0.1
wheel 0.34.2
widgetsnbextension 3.5.1
win-inet-pton 1.1.0
wincertstore 0.2
wrapt 1.11.2
xarray 0.10.8
xlrd 1.1.0
yahoo-historical 0.3.2
zict 0.1.3
zipp 2.2.0
任何能让代码跑通的提示都非常感谢 ;-)!
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)