Problem description
AttributeError: 'str' object has no attribute 'keys'
Here is the main code:
import numpy as np

def generate_arrays_for_training(indexPat, paths, start=0, end=100):
    while True:
        # Use only the [start%, end%) portion of the file list
        from_ = int(len(paths) / 100 * start)
        to_ = int(len(paths) / 100 * end)
        for i in range(from_, to_):
            f = paths[i]
            x = np.load(PathSpectogramFolder + f)
            # One one-hot label row per row of x
            if 'P' in f:
                y = np.repeat([[0, 1]], x.shape[0], axis=0)
            else:
                y = np.repeat([[1, 0]], x.shape[0], axis=0)
            yield (x, y)

history = model.fit_generator(generate_arrays_for_training(indexPat, filesPath, end=75),  ## problem here
                              steps_per_epoch=int((len(filesPath) - int(len(filesPath) / 100 * 25))),
                              validation_steps=int((len(filesPath) - int(len(filesPath) / 100 * 75))),
                              verbose=2,
                              class_weight="balanced",
                              epochs=15, max_queue_size=2, shuffle=True,
                              callbacks=[callback])
Here, generate_arrays_for_training yields x and y: x is a 2D array of floats, and y holds the label [0,1] (or [1,0]), repeated once for each row of x.
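To see what the generator yields without the .npy files on disk, here is a minimal sketch that swaps np.load(PathSpectogramFolder + f) for a lookup in an in-memory dict. The dict fake_files, its file names, and the renamed first parameter files (the original indexPat is unused) are hypothetical; the percentage slicing and the labeling rule ('P' in the file name gives [0,1], otherwise [1,0]) mirror the code above.

```python
import numpy as np

def generate_arrays_for_training(files, paths, start=0, end=100):
    """Yield (x, y) pairs forever. `files` is a hypothetical stand-in for
    np.load(PathSpectogramFolder + f): it maps file name -> 2D float array."""
    while True:
        # Use only the [start%, end%) portion of the file list
        from_ = int(len(paths) / 100 * start)
        to_ = int(len(paths) / 100 * end)
        for f in paths[from_:to_]:
            x = files[f]
            # Uppercase 'P' in the name selects [0, 1]; everything else is [1, 0]
            label = [[0, 1]] if 'P' in f else [[1, 0]]
            # Repeat the label so y has one row per row of x
            y = np.repeat(label, x.shape[0], axis=0)
            yield x, y

# Hypothetical data: 4 files, each with 3 rows of 5 features
fake_files = {name: np.random.rand(3, 5)
              for name in ["a_P.npy", "b_I.npy", "c_P.npy", "d_I.npy"]}
paths = sorted(fake_files)

x, y = next(generate_arrays_for_training(fake_files, paths, end=75))
print(x.shape, y.shape)  # (3, 5) (3, 2)
```

With end=75 only the first three of the four files are used per pass, and each y has exactly as many rows as its x, which is what Keras expects from a (inputs, targets) pair.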
Error:
Traceback (most recent call last):
  File "/home/user1/thesis2/CNN_dwt2.py", line 437, in <module>
    main()
  File "/home/user1/thesis2/CNN_dwt2.py", line 316, in main
    history=model.fit_generator(generate_arrays_for_training(indexPat,end=75),
  File "/home/user1/.local/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 324, in new_func
    return func(*args, **kwargs)
  File "/home/user1/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1815, in fit_generator
    return self.fit(
  File "/home/user1/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self,*args,line 1049,in fit
    data_handler = data_adapter.DataHandler(
  File "/home/user1/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/data_adapter.py", line 1122, in __init__
    dataset = dataset.map(_make_class_weight_map_fn(class_weight))
  File "/home/user1/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/data_adapter.py", line 1295, in _make_class_weight_map_fn
    class_ids = list(sorted(class_weight.keys()))
AttributeError: 'str' object has no attribute 'keys'
Solution
Your problem is caused by the "balanced" argument you pass to class_weight.
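According to the model.fit() reference, this argument should be a dict: "Optional dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only). This can be useful to tell the model to 'pay more attention' to samples from an under-represented class."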
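For reference, here is a sketch of what a valid class_weight value looks like. The string "balanced" is a scikit-learn convention (sklearn.utils.class_weight.compute_class_weight accepts it); Keras does not interpret it and, as the traceback shows, simply calls .keys() on whatever it receives. The helper balanced_class_weight below is hypothetical: it reproduces scikit-learn's documented "balanced" formula, n_samples / (n_classes * count_of_class), with NumPy only, and returns the int-to-float dict Keras expects.

```python
import numpy as np

def balanced_class_weight(labels):
    """Return {class_index: weight} using the 'balanced' heuristic
    n_samples / (n_classes * count_of_class). `labels` are integer class ids."""
    labels = np.asarray(labels)
    # Only classes that actually occur in `labels` get a weight
    classes, counts = np.unique(labels, return_counts=True)
    weights = len(labels) / (len(classes) * counts)
    # Plain Python int keys and float values, as a Keras class_weight dict
    return {int(c): float(w) for c, w in zip(classes, weights)}

print(balanced_class_weight([0, 0, 0, 1]))  # {0: 0.6666666666666666, 1: 2.0}
```

The minority class (1) gets three times the weight of the majority class (0), which is the "pay more attention" effect the documentation describes.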
Try class_weight=None for testing; it should get rid of the original error. Later, provide a proper dict as class_weight to address the dataset imbalance.
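Because the training data in the question only arrive through a generator, one way to build that dict is to read one epoch's worth of batches, count the one-hot label rows per class, and apply the "balanced" heuristic n_samples / (n_classes * count_of_class) as scikit-learn defines it. This sketch assumes the labels are one-hot rows such as [1,0] and [0,1], so the class index is the position of the 1 (np.argmax); class_weight_from_batches, n_batches and toy_batches are hypothetical names, and n_batches should match the number of steps you actually train on.

```python
import itertools
import numpy as np

def class_weight_from_batches(batches, n_batches, n_classes=2):
    """Count one-hot label rows over n_batches (x, y) pairs and return
    {class_index: n_samples / (n_classes * count)}."""
    counts = np.zeros(n_classes, dtype=np.int64)
    # islice stops an infinite generator after n_batches items
    for _, y in itertools.islice(batches, n_batches):
        # argmax turns each one-hot row into its class index
        counts += np.bincount(np.argmax(y, axis=1), minlength=n_classes)
    total = counts.sum()
    # Classes never seen would divide by zero, so they are left out
    return {i: float(total / (n_classes * c)) for i, c in enumerate(counts) if c > 0}

# Synthetic batches: three of class 0 and one of class 1, 2 rows each
def toy_batches():
    while True:
        for label in ([1, 0], [1, 0], [1, 0], [0, 1]):
            yield np.zeros((2, 5)), np.repeat([label], 2, axis=0)

print(class_weight_from_batches(toy_batches(), n_batches=4))  # {0: 0.6666666666666666, 1: 2.0}
```

With the question's generator this would be called as class_weight_from_batches(generate_arrays_for_training(indexPat, filesPath, end=75), n_batches=...) and the result passed as class_weight= in place of "balanced". Note that this reads every training file once more before training starts.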