我在 CNN 图像增强中有一些代码错误

问题描述

CNN 运行前的增强过程出现错误

这是带有 MNIST 数据的代码

#import packages    
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import to_categorical
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy import ndimage


# Training hyper-parameters.
batch_size = 100
training_epochs = 15
learning_rate = 0.001

# Checkpoint location: <current working dir>/checkpoints/minst_cnn_best
ckpt_dir_name = 'checkpoints'
model_dir_name = 'minst_cnn_best'
cur_dir = os.getcwd()
checkpoint_dir = os.path.join(cur_dir, ckpt_dir_name, model_dir_name)

# Prefix path inside the checkpoint directory, reusing the model directory name.
checkpoint_prefix = os.path.join(checkpoint_dir, model_dir_name)

# Create the directory tree up front; an already existing directory is fine.
os.makedirs(checkpoint_dir, exist_ok=True)

#dataset

## MNIST Dataset #########################################################
# Handwritten digits: the class name of label i is the string form of i.
mnist = keras.datasets.mnist
class_names = [str(digit) for digit in range(10)]
##########################################################################

## Fashion MNIST Dataset (alternative; uncomment to use it instead) ######
#mnist = keras.datasets.fashion_mnist
#class_names = ['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']
##########################################################################

以下代码是错误的主要来源。定义 data_augmentation 函数后,出现了一些错误。

def data_augmentation(images,labels):
    """Expand a dataset with randomly rotated and shifted copies of each image.

    For every (image, label) pair the original image is kept and followed by
    four augmented copies carrying the same label, so the output holds five
    entries per input entry.

    Args:
        images: Iterable of 2-D image arrays (for example an array of shape
            (N, H, W)). The shift below supplies exactly two offsets, so each
            image must have two axes.
        labels: Iterable of labels, paired with ``images`` through ``zip``.

    Returns:
        A tuple ``(aug_images, aug_labels)`` of NumPy arrays built from the
        collected images and labels, in input order (original first, then its
        four augmented copies).
    """
    aug_images = []
    aug_labels = []

    for x, y in zip(images, labels):
        aug_images.append(x)
        aug_labels.append(y)

        # Pixels uncovered by the rotation/shift are filled with the image's
        # median value.
        bg_value = np.median(x)
        for _ in range(4):
            # ndimage.rotate needs a *scalar* angle in degrees. Passing a
            # shape-(1,) array makes its internal rotation matrix 3-D and
            # raises "matmul: Input operand 1 has a mismatch ...".
            # randint's upper bound is exclusive, so angle is in [-15, 14].
            angle = np.random.randint(-15, 15)
            rot_img = ndimage.rotate(x, angle, reshape=False, cval=bg_value)

            # One integer offset per image axis, each drawn from [-2, 1].
            shift = np.random.randint(-2, 2, 2)
            shift_img = ndimage.shift(rot_img, shift, cval=bg_value)

            aug_images.append(shift_img)
            aug_labels.append(y)

    aug_images = np.array(aug_images)
    aug_labels = np.array(aug_labels)
    return aug_images, aug_labels

数据增强

# Augment only the training split; the test split is left untouched.
train_images, train_labels = data_augmentation(train_images, train_labels)

# Cast to float32, divide by 255 and append a trailing axis to each image array.
train_images = np.expand_dims(train_images.astype(np.float32) / 255., axis=-1)
test_images = np.expand_dims(test_images.astype(np.float32) / 255., axis=-1)

# One-hot encode the labels over 10 classes.
train_labels = to_categorical(train_labels, 10)
test_labels = to_categorical(test_labels, 10)

# Training pipeline: shuffle, then batch. Test pipeline: batch only.
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
train_dataset = train_dataset.shuffle(buffer_size=500000).batch(batch_size)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
test_dataset = test_dataset.batch(batch_size)

这是我的错误信息。我认为这是某种矩阵维度问题,但我不确定。谢谢。

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-40-09317cf7f945> in <module>()
----> 1 train_images,train_labels = data_augmentation(train_images,train_labels)
      2 
      3 train_images = train_images.astype(np.float32) / 255.
      4 test_images = test_images.astype(np.float32) / 255.
      5 train_images = np.expand_dims(train_images,axis=-1)

1 frames
/usr/local/lib/python3.6/dist-packages/scipy/ndimage/interpolation.py in rotate(input,axes,reshape,output,order,mode,cval,prefilter)
    716         out_plane_shape = img_shape[axes]
    717 
--> 718     out_center = rot_matrix @ ((out_plane_shape - 1) / 2)
    719     in_center = (in_plane_shape - 1) / 2
    720     offset = in_center - out_center

ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0,with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 1)

解决方法

angle = np.random.randint(-15,15,1)
shift = np.random.randint(-2,2,2)

这两行将分别生成一个形状为 (1,) 和 (2,) 的数组。我想你的意思是:

    angle = np.random.randint(-15,15)
    shift = np.random.randint(-2,2)

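因为您希望它们是整数。真正导致上述错误的是 angle:ndimage.rotate 需要标量角度,传入形状为 (1,) 的数组会使内部的旋转矩阵多出一维,从而引发 matmul 维度不匹配。注意:对二维图像而言,形状为 (2,) 的 shift 本身是合法的(每个轴一个偏移量);如果改为标量,两个轴将使用相同的偏移量。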