Problem description

I'm trying to build a basic CNN in TensorFlow from a custom dataset held in 2-D NumPy arrays. I can't get the input data to line up with the input_shape or batch_input_shape argument of the convolutional layer. I've tried every ordering of the dimensions, and I've followed the documentation's examples, but I can't work out why it still raises an error.

Any help would be greatly appreciated!
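For context, the two arguments mean different things in Keras: input_shape describes one sample and excludes the batch dimension, while batch_input_shape includes a fixed batch size as its first entry. A minimal sketch using this post's dimensions (the filter count and kernel size are arbitrary here):

from tensorflow.keras import layers

# input_shape: one sample is (10000, 10, 1); the batch dimension is implicit (None)
conv_a = layers.Conv2D(32, kernel_size=(2, 2), activation='relu',
                       input_shape=(10000, 10, 1))

# batch_input_shape: the first entry is a FIXED batch size of 4
conv_b = layers.Conv2D(32, kernel_size=(2, 2), activation='relu',
                       batch_input_shape=(4, 10000, 10, 1))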
import os
import pickle
import pandas as pd
import matplotlib.pyplot as plt  # was `import matplotlib as plt`, which does not expose the plotting API
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, datasets, layers
BATCH_SIZE = 4
TRAIN_SPLIT = 0.8
VAL_SPLIT = 0.1
TEST_SPLIT = 0.1

CWD = os.getcwd()  # CWD was undefined in the original post; assuming the .npy files sit in the working directory
with open(CWD + '/CLNY_X.npy', mode='rb') as f:
    Xt = np.load(f, allow_pickle=True)
with open(CWD + '/CLNY_Y.npy', mode='rb') as f:
    Y = np.load(f, allow_pickle=True)

X = Xt.reshape(Xt.shape + (1,))  # append a trailing channels axis
DATASIZE = Y.shape[0]
print("Datasize: ", DATASIZE)
Datasize: 172
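As an aside, X = Xt.reshape(Xt.shape + (1,)) appends a trailing channels axis, turning (172, 10000, 10) into (172, 10000, 10, 1), assuming those are Xt's dimensions; np.expand_dims is an equivalent and arguably clearer spelling:

X = np.expand_dims(Xt, axis=-1)  # (172, 10000, 10) -> (172, 10000, 10, 1)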
# test out with different period moving averages, so we take the
dataset = tf.data.Dataset.from_tensor_slices((X, Y))
for feat, targ in dataset.take(1):
    print('NRows: {}, NCols: {}, Target: {}\nFeat: {}'.format(len(feat), len(feat[0]), targ, feat))
NRows: 10000, NCols: 10, Target: 0.2587999999523163
Feat: [[[5.0292000e+01]
[1.5998565e-01]
[7.5094378e-01]
...
[1.0000000e+00]
[2.5231593e-05]
[1.4535466e-01]]
[[5.0492001e+01]
[2.9965147e-01]
[1.4065099e+00]
...
[1.8729897e+00]
[4.7258512e-05]
[2.7224776e-01]]
[[5.0692001e+01]
[2.9965451e-01]
[1.4065243e+00]
...
[1.8730087e+00]
[4.7258993e-05]
[2.7225053e-01]]
...
[[0.0000000e+00]
[0.0000000e+00]
[0.0000000e+00]
...
[0.0000000e+00]
[0.0000000e+00]
[0.0000000e+00]]
[[0.0000000e+00]
[0.0000000e+00]
[0.0000000e+00]
...
[0.0000000e+00]
[0.0000000e+00]
[0.0000000e+00]]
[[0.0000000e+00]
[0.0000000e+00]
[0.0000000e+00]
...
[0.0000000e+00]
[0.0000000e+00]
[0.0000000e+00]]]
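So each dataset element is one (features, target) pair with a per-sample feature shape of (10000, 10, 1) — this, not the batched shape, is what the first layer's input_shape has to match. On TF 2.x this can be checked directly (the output comment is what I'd expect, not a verified run):

print(dataset.element_spec)
# roughly: (TensorSpec(shape=(10000, 10, 1), dtype=tf.float64, name=None),
#           TensorSpec(shape=(), dtype=tf.float64, name=None))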
train_size = int(DATASIZE*TRAIN_SPLIT)
val_size = int(DATASIZE*VAL_SPLIT)
test_size = int(DATASIZE*TEST_SPLIT)
dataset = dataset.shuffle(DATASIZE)
train_dataset = dataset.take(train_size).batch(BATCH_SIZE)
test_dataset = dataset.skip(train_size)
val_dataset = dataset.skip(test_size)
test_dataset = dataset.take(test_size)
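As an aside, this take/skip bookkeeping makes the splits overlap: test_dataset is assigned twice, val_dataset is carved with skip(test_size) rather than skip(train_size), and shuffle reshuffles each epoch by default, so the splits don't stay disjoint. A non-overlapping sketch (untested, reusing the same names):

dataset = dataset.shuffle(DATASIZE, reshuffle_each_iteration=False)
train_dataset = dataset.take(train_size).batch(BATCH_SIZE)
val_dataset = dataset.skip(train_size).take(val_size).batch(BATCH_SIZE)
test_dataset = dataset.skip(train_size + val_size).batch(BATCH_SIZE)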
CONVERTED_LENGTH = 10000
CONVERTED_WIDTH = 10

model = models.Sequential()
#model.add(layers.Conv1D(32, kernel_size=(10), activation='relu', data_format='channels_last', batch_input_shape=(CONVERTED_LENGTH, CONVERTED_WIDTH, 1)))
model.add(layers.Conv2D(32, kernel_size=(2, 2), activation='relu', batch_input_shape=(CONVERTED_LENGTH, CONVERTED_WIDTH, BATCH_SIZE, 1)))
model.add(layers.Flatten())
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(1, activation='softmax'))
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (10000,9,3,32) 160
_________________________________________________________________
flatten (Flatten) (10000,864) 0
_________________________________________________________________
dense (Dense) (10000,32) 27680
_________________________________________________________________
dense_1 (Dense) (10000,1) 33
=================================================================
Total params: 27,873
Trainable params: 27,873
Non-trainable params: 0
_________________________________________________________________
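Note that the summary already shows the mismatch: with batch_input_shape=(10000, 10, 4, 1), Keras treats 10000 as the batch size and (10, 4, 1) as one sample, so the 2x2 convolution yields (10-2+1) x (4-2+1) = 9 x 3 spatial positions with 32 filters, and Flatten gives 9*3*32 = 864 units — exactly the (10000, 9, 3, 32) and (10000, 864) shapes printed above. The 10000-row dimension of a single sample is being consumed as the batch axis.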
model.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError(), metrics=['accuracy'])
history = model.fit(train_dataset, epochs=10, validation_data=(val_dataset))  # add the validation_data=(test_data, test_targets)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-9-c0e1d31b7f23> in <module>
      3               metrics=['accuracy'])
      4 
----> 5 history = model.fit(train_dataset, test_targets)

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    817         max_queue_size=max_queue_size,
    818         workers=workers,
--> 819         use_multiprocessing=use_multiprocessing)
    820 
    821   def evaluate(self,

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py in fit(self, model, **kwargs)
    233         max_queue_size=max_queue_size,
    234         workers=workers,
--> 235         use_multiprocessing=use_multiprocessing)
    236 
    237     total_samples = _get_total_number_of_samples(training_data_adapter)

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py in _process_training_inputs(model, sample_weights, class_weights, distribution_strategy, use_multiprocessing)
    591         max_queue_size=max_queue_size,
    592         workers=workers,
--> 593         use_multiprocessing=use_multiprocessing)
    594     val_adapter = None
    595     if validation_data:

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py in _process_inputs(model, mode, steps, use_multiprocessing)
    704         max_queue_size=max_queue_size,
    705         workers=workers,
--> 706         use_multiprocessing=use_multiprocessing)
    707 
    708   return adapter

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\keras\engine\data_adapter.py in __init__(self, standardize_function, **kwargs)
    700 
    701     if standardize_function is not None:
--> 702       x = standardize_function(x)
    703 
    704     # Note that the dataset instance is immutable, its fine to reusing the user

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py in standardize_function(dataset)
    682         return x, y
    683       return x, sample_weights
--> 684     return dataset.map(map_fn, num_parallel_calls=dataset_ops.AUTOTUNE)
    685 
    686   if mode == ModeKeys.PREDICT:

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\data\ops\dataset_ops.py in map(self, map_func, num_parallel_calls)
   1589     else:
   1590       return ParallelMapDataset(
-> 1591           self, num_parallel_calls, preserve_cardinality=True)
   1592 
   1593   def flat_map(self, map_func):

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\data\ops\dataset_ops.py in __init__(self, input_dataset, use_inter_op_parallelism, preserve_cardinality, use_legacy_function)
   3924         self._transformation_name(),
   3925         dataset=input_dataset,
-> 3926         use_legacy_function=use_legacy_function)
   3927     self._num_parallel_calls = ops.convert_to_tensor(
   3928         num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\data\ops\dataset_ops.py in __init__(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)
   3145     with tracking.resource_tracker_scope(resource_tracker):
   3146       # TODO(b/141462134): Switch to using garbage collection.
-> 3147       self._function = wrapper_fn._get_concrete_function_internal()
   3148 
   3149     if add_to_graph:

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\eager\function.py in _get_concrete_function_internal(self, *args, **kwargs)
   2393     """Bypasses error checking when getting a graph function."""
   2394     graph_function = self._get_concrete_function_internal_garbage_collected(
-> 2395         *args, **kwargs)
   2396     # We're returning this concrete function to someone, and they may keep a
   2397     # reference to the FuncGraph without keeping a reference to the

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\eager\function.py in _get_concrete_function_internal_garbage_collected(self, **kwargs)
   2387       args, kwargs = None, None
   2388     with self._lock:
-> 2389       graph_function, _, _ = self._maybe_define_function(args, kwargs)
   2390     return graph_function
   2391 

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\eager\function.py in _maybe_define_function(self, args, kwargs)
   2701 
   2702       self._function_cache.missed.add(call_context_key)
-> 2703       graph_function = self._create_graph_function(args, kwargs)
   2704       self._function_cache.primary[cache_key] = graph_function
   2705       return graph_function, kwargs

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\eager\function.py in _create_graph_function(self, kwargs, override_flat_arg_shapes)
   2591             arg_names=arg_names,
   2592             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2593             capture_by_value=self._capture_by_value),
   2594         self._function_attributes,
   2595         # Tell the ConcreteFunction to clean up its graph once it goes out of

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\framework\func_graph.py in func_graph_from_py_func(name, python_func, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    976                                           converted_func)
    977 
--> 978       func_outputs = python_func(*func_args, **func_kwargs)
    979 
    980       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\data\ops\dataset_ops.py in wrapper_fn(*args)
   3138         attributes=defun_kwargs)
   3139     def wrapper_fn(*args):  # pylint: disable=missing-docstring
-> 3140       ret = _wrapper_helper(*args)
   3141       ret = structure.to_tensor_list(self._output_structure, ret)
   3142       return [ops.convert_to_tensor(t) for t in ret]

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\data\ops\dataset_ops.py in _wrapper_helper(*args)
   3080       nested_args = (nested_args,)
   3081 
-> 3082     ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
   3083     # If `func` returns a list of tensors, `nest.flatten()` and
   3084     # `ops.convert_to_tensor()` would conspire to attempt to stack

C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\autograph\impl\api.py in wrapper(*args, **kwargs)
    235       except Exception as e:  # pylint:disable=broad-except
    236         if hasattr(e, 'ag_error_metadata'):
--> 237           raise e.ag_error_metadata.to_exception(e)
    238         else:
    239           raise

ValueError: in converted code:

    C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py:677 map_fn
        batch_size=None)
    C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\keras\engine\training.py:2410 _standardize_tensors
        exception_prefix='input')
    C:\ProgramData\Miniconda3\lib\site-packages\tensorflow_core\python\keras\engine\training_utils.py:582 standardize_input_data
        str(data_shape))

    ValueError: Error when checking input: expected conv2d_input to have shape (10, 4, 1) but got array with shape (10000, 10, 1)
It always says the input data isn't in the expected format, or that ndim is wrong because some dimension gets added as None. I just can't get it to run!
Solution
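A likely diagnosis, based on the final ValueError: the dataset yields one sample of shape (10000, 10, 1) at a time, while batch_input_shape=(CONVERTED_LENGTH, CONVERTED_WIDTH, BATCH_SIZE, 1) tells Keras that one sample is (10, 4, 1). Declaring the per-sample shape with input_shape and letting .batch(BATCH_SIZE) supply the batch dimension should align the two. The sketch below is untested against the original data and assumes the non-overlapping, batched splits shown earlier; it also swaps the single-unit softmax for a linear output, since softmax over one unit always emits 1.0 and makes MSE training meaningless:

model = models.Sequential()
# One sample is (CONVERTED_LENGTH, CONVERTED_WIDTH, 1) = (10000, 10, 1);
# the batch dimension is implicit and is supplied by .batch(BATCH_SIZE).
model.add(layers.Conv2D(32, kernel_size=(2, 2), activation='relu',
                        input_shape=(CONVERTED_LENGTH, CONVERTED_WIDTH, 1)))
model.add(layers.Flatten())
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(1))  # linear output for a regression target

model.compile(optimizer='adam',
              loss=tf.keras.losses.MeanSquaredError(),
              metrics=['mae'])  # accuracy is not meaningful for regression
history = model.fit(train_dataset, epochs=10, validation_data=val_dataset)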