Problem description
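I want to apply some data augmentation while reading TFRecords. I have been using tf.numpy_function and wrapping it in a tf.function, but my training is very slow.
How can I speed up this process with CuPy?
For example, I use affine_transform from SciPy, but I found that CuPy has a similar function.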
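For reference, the matrices Rx, Ry and Rz built in scipy_rotate are the elementary rotations about the three axes. SciPy's affine_transform treats the matrix as a pull-back: for each output index it computes the input position to sample (per its documentation, position = matrix · o + offset):

$$
R_x(\alpha)=\begin{pmatrix}1&0&0\\0&\cos\alpha&-\sin\alpha\\0&\sin\alpha&\cos\alpha\end{pmatrix},\quad
R_y(\beta)=\begin{pmatrix}\cos\beta&0&\sin\beta\\0&1&0\\-\sin\beta&0&\cos\beta\end{pmatrix},\quad
R_z(\gamma)=\begin{pmatrix}\cos\gamma&-\sin\gamma&0\\\sin\gamma&\cos\gamma&0\\0&0&1\end{pmatrix}
$$

$$
R = R_x(\alpha)\,R_y(\beta)\,R_z(\gamma),\qquad
\mathbf{x}_{\mathrm{in}} = R\,\mathbf{x}_{\mathrm{out}} + \mathbf{t}
$$

With offset t = 0 the volume is rotated about index (0, 0, 0), i.e. a corner. To rotate about the centre c = (shape − 1)/2 instead, one would choose t = c − R c, because then x_out = c gives x_in = R c + c − R c = c.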
This is the function I pass to dataset.map when reading the TFRecords in TensorFlow:
import numpy as np
import tensorflow as tf
from scipy.ndimage import affine_transform

@tf.function
def rotation3D(volume, label):
    def scipy_rotate(volume):
        # Three random angles in [0, 30] degrees, converted to radians
        alpha, beta, gamma = np.random.randint(0, 31, size=3) / 180 * np.pi
        Rx = np.array([[1, 0, 0],
                       [0, np.cos(alpha), -np.sin(alpha)],
                       [0, np.sin(alpha), np.cos(alpha)]])
        Ry = np.array([[np.cos(beta), 0, np.sin(beta)],
                       [0, 1, 0],
                       [-np.sin(beta), 0, np.cos(beta)]])
        Rz = np.array([[np.cos(gamma), -np.sin(gamma), 0],
                       [np.sin(gamma), np.cos(gamma), 0],
                       [0, 0, 1]])
        R = np.dot(np.dot(Rx, Ry), Rz)
        volume_rot = np.empty_like(volume)
        # This body runs as plain NumPy, so use range, not tf.range
        for channel in range(volume.shape[-1]):
            volume_rot[..., channel] = affine_transform(
                volume[..., channel], R, offset=0, order=3, mode='nearest')
        return volume_rot

    augmented_volume = tf.numpy_function(scipy_rotate, [volume], tf.float32)
    # numpy_function drops the static shape; restore it for later ops
    augmented_volume.set_shape(volume.shape)
    return augmented_volume, label
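Since the question asks about CuPy, here is a minimal sketch of the channel loop ported to CuPy's cupyx.scipy.ndimage.affine_transform, assuming float volumes laid out as (X, Y, Z, C). The name rotate_volume is hypothetical. It falls back to NumPy/SciPy when CuPy is not installed, so the same code also runs on the CPU. Unlike the offset=0 in the question, it rotates about the volume centre; that is an assumption about the intended augmentation.

```python
import numpy as np

try:  # GPU path: CuPy's port of scipy.ndimage
    import cupy as xp
    from cupyx.scipy.ndimage import affine_transform
except ImportError:  # CPU fallback so the sketch also runs without CuPy
    xp = np
    from scipy.ndimage import affine_transform


def rotate_volume(volume, R, order=3):
    """Rotate every channel of an (X, Y, Z, C) volume by the 3x3 matrix R.

    The rotation is about the volume centre. Returns a float32 NumPy array,
    which is what tf.numpy_function expects back.
    """
    R = np.asarray(R, dtype=np.float64)
    centre = (np.asarray(volume.shape[:3]) - 1) / 2.0
    # affine_transform maps output index o to input position R @ o + offset,
    # so offset = centre - R @ centre keeps the centre fixed.
    offset = tuple(centre - R @ centre)
    vol = xp.asarray(volume, dtype=xp.float32)  # one host -> device copy
    mat = xp.asarray(R, dtype=xp.float32)
    out = xp.empty_like(vol)
    for ch in range(vol.shape[-1]):
        out[..., ch] = affine_transform(vol[..., ch], mat, offset=offset,
                                        order=order, mode='nearest')
    # copy back to the host; with the NumPy fallback `out` already lives there
    return out if xp is np else xp.asnumpy(out)
```

Inside scipy_rotate, the loop could then become return rotate_volume(volume, R). Each call still copies the volume to the GPU and back, so whether this beats SciPy depends on the volume size and is worth measuring. TensorFlow by default reserves most of the GPU memory, so enabling tf.config.experimental.set_memory_growth may be needed before CuPy can allocate.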
This is the function I use to read the TFRecords and apply the augmentation:
def input_fn(filenames, subset, batch_size, buffer_size=512, data_augmentation=True):
    # Args:
    #   filenames: Filenames for the TFRecords files.
    #   subset: Which subset to build: 'train', 'valid' or 'test'.
    #   batch_size: Return batches of this size.
    #   buffer_size: Read buffers of this size. The random shuffling
    #                is done on the buffer, so it must be big enough.

    # Create a TensorFlow Dataset object which can read and
    # shuffle data from TFRecords files.
    AUTO = tf.data.experimental.AUTOTUNE
    dataset = tf.data.TFRecordDataset(filenames=filenames)

    # Parse the serialized data in the TFRecords files (parse_example not shown).
    # This returns TensorFlow tensors for the image and labels.
    dataset = dataset.map(parse_example, num_parallel_calls=AUTO)

    # Make the training and validation datasets iterate forever
    if subset == 'train' or subset == 'valid':
        dataset = dataset.repeat()

    # Shuffle the training and validation datasets
    if subset != 'test':
        dataset = dataset.shuffle(buffer_size=buffer_size)

    if subset != 'test' and data_augmentation:
        # dataset = dataset.map(elastic3D, num_parallel_calls=AUTO)
        # dataset = dataset.map(flip3D, num_parallel_calls=AUTO)
        dataset = dataset.map(rotation3D, num_parallel_calls=AUTO)
        # dataset = dataset.map(blur3D, num_parallel_calls=AUTO)

    # Set the batch size
    dataset = dataset.batch(batch_size=batch_size)
    dataset = dataset.prefetch(2)
    return dataset
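One further option, not from the question: because rotation3D is mapped before batch, tf.numpy_function is entered once per sample, and the tf.data.Dataset.map documentation notes that tf.numpy_function generally prevents parallel execution because of the Python GIL, so num_parallel_calls=AUTO helps little here. The sketch below batches first and rotates a whole batch per call, so the per-call overhead and the host/device copy are paid once per batch. It assumes float batches shaped (B, X, Y, Z, C), uses CuPy with a SciPy fallback, and random_rotation / rotate_batch are hypothetical names; the TensorFlow wiring is shown only as comments.

```python
import numpy as np

try:  # GPU path: CuPy's port of scipy.ndimage
    import cupy as xp
    from cupyx.scipy.ndimage import affine_transform
except ImportError:  # CPU fallback so the sketch also runs without CuPy
    xp = np
    from scipy.ndimage import affine_transform


def random_rotation(rng, max_deg=30):
    """R = Rx @ Ry @ Rz for three random integer angles in [0, max_deg] degrees."""
    a, b, g = np.deg2rad(rng.randint(0, max_deg + 1, size=3))
    rx = np.array([[1, 0, 0], [0, np.cos(a), -np.sin(a)], [0, np.sin(a), np.cos(a)]])
    ry = np.array([[np.cos(b), 0, np.sin(b)], [0, 1, 0], [-np.sin(b), 0, np.cos(b)]])
    rz = np.array([[np.cos(g), -np.sin(g), 0], [np.sin(g), np.cos(g), 0], [0, 0, 1]])
    return rx @ ry @ rz


def rotate_batch(batch, seed=None):
    """Apply an independent random rotation to each volume of a (B, X, Y, Z, C) batch."""
    rng = np.random.RandomState(seed)
    vols = xp.asarray(batch, dtype=xp.float32)  # one host -> device copy per batch
    out = xp.empty_like(vols)
    centre = (np.asarray(vols.shape[1:4]) - 1) / 2.0
    for i in range(vols.shape[0]):
        R = random_rotation(rng)  # a fresh rotation for every sample
        offset = tuple(centre - R @ centre)  # rotate about the volume centre
        mat = xp.asarray(R, dtype=xp.float32)
        for ch in range(vols.shape[-1]):
            out[i, ..., ch] = affine_transform(vols[i, ..., ch], mat, offset=offset,
                                               order=3, mode='nearest')
    return out if xp is np else xp.asnumpy(out)

# Assumed wiring in input_fn, replacing the per-sample map(rotation3D, ...):
#   dataset = dataset.batch(batch_size=batch_size)
#   def augment_batch(volumes, labels):
#       rotated = tf.numpy_function(rotate_batch, [volumes], tf.float32)
#       rotated.set_shape(volumes.shape)  # numpy_function drops the static shape
#       return rotated, labels
#   dataset = dataset.map(augment_batch).prefetch(AUTO)
```

The design trades per-sample parallelism for fewer, larger calls; with large batches the GPU copy dominates less, but the gain should be confirmed by timing the pipeline on its own.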