实现 autokeras 时间序列模型时出错

问题描述

我试图在串行数据集上实现 autokeras TimeSeriesForecaster。下面分别给出数据集的特征和标签

df1_x =

enter image description here

df1_y = 
0    2.5
1    2.1
2    2.2
3    2.2
4    1.5
Name: target_carbon_monoxide,dtype: float64

AutoML 准备

#parameters
predict_from = 1
predict_until = 1
lookback = 3
clf = ak.TimeseriesForecaster(
    lookback=lookback,predict_from=predict_from,predict_until=predict_until,max_trials=1,objective="val_loss",)
# Train the TimeSeriesForecaster with train data
clf.fit(
    x=df1_x,y=df1_y,epochs=10,)

数据框没有 NaN 值,特征数据框的形状是 (7111,8),即二维数据框。

错误如下:

Search: Running Trial #1

Hyperparameter    |Value             |Best Value So Far 
timeseries_bloc...|True              |?                 
timeseries_bloc...|lstm              |?                 
timeseries_bloc...|3                 |?                 
regression_head...|0                 |?                 
optimizer         |adam              |?                 
learning_rate     |0.001             |?                 

Epoch 1/10
    173/UnkNown - 4s 5ms/step - loss: 2.2421 - mean_squared_error: 2.2421
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
/tmp/ipykernel_11292/1163792963.py in <module>
     10 )
     11 # Train the TimeSeriesForecaster with train data
---> 12 clf.fit(
     13     x=df1_x,14     y=df1_y,InvalidArgumentError:  Incompatible shapes: [32,1] vs. [30,1]
     [[node mean_squared_error/SquaredDifference (defined at home/samar/.local/lib/python3.8/site-packages/autokeras/utils/utils.py:88) ]] [Op:__inference_train_function_13895]

Function call stack:
train_function

解决方法

您需要向 fit() 提供验证数据。如果将您拥有的数据 (df1) 拆分为训练集和验证,并将它们都提供给 fit(),则训练将运行良好。尝试使用 train_test_split 拆分数据,或者您可以手动进行。 您的代码将是这样的:

from sklearn.model_selection import train_test_split
df1_x,df1_x_eval,df1_y,df1_y_eval = train_test_split(df1_x,test_size=0.25,random_state=42)
clf.fit(
    x=df1_x,y=df1_y,validation_data = (df_x_eval,df_y_eval),epochs=10)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...