问题描述
CNN 运行前的增强过程出现错误。
这是带有 MNIST 数据的代码。
#import packages
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import to_categorical
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy import ndimage
learning_rate = 0.001
training_epochs = 15
batch_size = 100
#check point
cur_dir = os.getcwd()
ckpt_dir_name = 'checkpoints'
model_dir_name = 'minst_cnn_best'
checkpoint_dir = os.path.join(cur_dir,ckpt_dir_name,model_dir_name)
os.makedirs(checkpoint_dir,exist_ok=True)
checkpoint_prefix = os.path.join(checkpoint_dir,model_dir_name)
#dataset
## MNIST Dataset #########################################################
mnist = keras.datasets.mnist
class_names = ['0','1','2','3','4','5','6','7','8','9']
##########################################################################
## Fashion MNIST Dataset #################################################
#mnist = keras.datasets.fashion_mnist
#class_names = ['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']
##########################################################################
以下代码是错误的主要来源。定义data_augmentation函数后,出现一些错误。
def data_augmentation(images,labels):
aug_images = []
aug_labels = []
for x,y in zip(images,labels):
aug_images.append(x)
aug_labels.append(y)
bg_value = np.median(x)
for _ in range(4):
angle = np.random.randint(-15,15,1)
rot_img = ndimage.rotate(x,angle,reshape=False,cval=bg_value)
shift = np.random.randint(-2,2,2)
shift_img = ndimage.shift(rot_img,shift,cval=bg_value)
aug_images.append(shift_img)
aug_labels.append(y)
aug_images = np.array(aug_images)
aug_labels = np.array(aug_labels)
return aug_images,aug_labels
数据增强
train_images,train_labels = data_augmentation(train_images,train_labels)
train_images = train_images.astype(np.float32) / 255.
test_images = test_images.astype(np.float32) / 255.
train_images = np.expand_dims(train_images,axis=-1)
test_images = np.expand_dims(test_images,axis=-1)
train_labels = to_categorical(train_labels,10)
test_labels = to_categorical(test_labels,10)
train_dataset = tf.data.Dataset.from_tensor_slices((train_images,train_labels)).shuffle(
buffer_size=500000).batch(batch_size)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images,test_labels)).batch(batch_size)
这是我的错误代码。我认为这是一种矩阵维数问题,但也不知道。 谢谢。
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-40-09317cf7f945> in <module>()
----> 1 train_images,train_labels)
2
3 train_images = train_images.astype(np.float32) / 255.
4 test_images = test_images.astype(np.float32) / 255.
5 train_images = np.expand_dims(train_images,axis=-1)
1 frames
/usr/local/lib/python3.6/dist-packages/scipy/ndimage/interpolation.py in rotate(input,axes,reshape,output,order,mode,cval,prefilter)
716 out_plane_shape = img_shape[axes]
717
--> 718 out_center = rot_matrix @ ((out_plane_shape - 1) / 2)
719 in_center = (in_plane_shape - 1) / 2
720 offset = in_center - out_center
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0,with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 1)
解决方法
angle = np.random.randint(-15,15,1)
shift = np.random.randint(-2,2,2)
这两行将分别生成一个形状为 (1,) 和 (2,) 的数组。我想你的意思是:
angle = np.random.randint(-15,15)
shift = np.random.randint(-2,2)
因为您希望它们是整数。